Fixed overhead bug in the auto-tuner (not in the benchmarks)

Philippe Tillet
2014-11-06 16:14:46 -05:00
parent 8dd492de23
commit 6595d20c34
9 changed files with 110 additions and 63 deletions

View File

@@ -71,6 +71,9 @@ namespace atidlas
       return res;
     }
+    std::vector<tree> const & estimators() const
+    { return estimators_; }
   private:
     std::vector<tree> estimators_;
   };
@@ -93,17 +96,27 @@ namespace atidlas
   public:
     model(random_forest const & predictor, std::vector< tools::shared_ptr<template_base> > const & templates,
-          viennacl::ocl::context & context, viennacl::ocl::device const & device) : predictor_(predictor), templates_(templates), context_(context), device_(device)
+          viennacl::ocl::context & context, viennacl::ocl::device const & device) : predictor_(new random_forest(predictor)), templates_(templates), context_(context), device_(device)
     { }
-    void execute(statements_container const & statements, bool bypass_predictor = false)
+    model(std::vector< tools::shared_ptr<template_base> > const & templates, viennacl::ocl::context & context, viennacl::ocl::device const & device) :
+      templates_(templates), context_(context), device_(device)
+    {}
+    model(template_base const & tp, viennacl::ocl::context & context, viennacl::ocl::device const & device) :
+      templates_(1,tp.clone()), context_(context), device_(device)
+    {}
+    void execute(statements_container const & statements, bool bypass_predictor = false, bool force_recompilation = false)
     {
+      bypass_predictor = bypass_predictor || predictor_.get()==NULL;
       if(lazy_programs_.empty())
       {
         std::string pname = tools::statements_representation(statements, BIND_TO_HANDLE);
-        init_program_compiler(pname, false);
-        init_program_compiler(pname + "_fb", false);
+        init_program_compiler(pname, force_recompilation);
+        init_program_compiler(pname + "_fb", force_recompilation);
         for(size_t i = 0 ; i < templates_.size() ; ++i)
         {
@@ -126,11 +139,10 @@ namespace atidlas
       //Default
       else
       {
-        std::vector<float> predictions = predictor_.predict(x);
+        std::vector<float> predictions = predictor_->predict(x);
         label = std::distance(predictions.begin(),std::min_element(predictions.begin(), predictions.end()));
       }
       //Execution
       templates_[label]->enqueue("k" + tools::to_string(label), lazy_programs_, statements);
     }
@@ -154,10 +166,9 @@ namespace atidlas
     }
   private:
-    random_forest predictor_;
+    tools::shared_ptr<random_forest> predictor_;
     templates_container templates_;
     std::map<std::vector<atidlas_int_t>, int> hardcoded_;
     viennacl::ocl::context & context_;
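Note on the hunks above: the model used to copy its random_forest by value, and execute unconditionally ran the predictor; holding it through a tools::shared_ptr lets the two new constructors omit it entirely, and execute then bypasses prediction. A minimal Python sketch of the resulting selection step (the names, and running template 0 on bypass, are illustrative assumptions, not the actual bindings):

    import numpy as np

    def select_template(predictor, features):
        # predictor: object whose predict(features) returns one predicted
        # execution time per template, or None when the model was built
        # from a bare template (no random forest attached).
        if predictor is None:
            return 0  # assumption: bypass runs the first (only) template
        predictions = predictor.predict(features)
        # Lowest predicted time wins -- mirrors the std::min_element /
        # std::distance pair in model::execute above.
        return int(np.argmin(predictions))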

View File

@@ -90,7 +90,7 @@ private:
 public:
   vector_axpy_template(vector_axpy_template::parameters_type const & parameters, binding_policy_t binding_policy = BIND_ALL_UNIQUE) : template_base_impl<vector_axpy_template, vector_axpy_parameters>(parameters, binding_policy), up_to_internal_size_(false)
-  { }
+  {}
   void up_to_internal_size(bool v)
   { up_to_internal_size_ = v; }
@@ -105,11 +105,12 @@ public:
   void enqueue(std::string const & kernel_prefix, std::vector<lazy_program_compiler> & programs, statements_container const & statements)
   {
     atidlas_int_t size = input_sizes(statements)[0];
-    viennacl::ocl::kernel * kernel;
-    if(p_.simd_width > 1 && (has_strided_access(statements) || (size%p_.simd_width>0) || has_misaligned_offset(statements)))
-      kernel = &programs[0].program().get_kernel(kernel_prefix+"0");
-    else
-      kernel = &programs[1].program().get_kernel(kernel_prefix+"1");
+    std::string kfallback = kernel_prefix;
+    kfallback+='0';
+    std::string kopt = kernel_prefix;
+    kopt+='1';
+    bool fallback = p_.simd_width > 1 && (has_strided_access(statements) || (size%p_.simd_width>0) || has_misaligned_offset(statements));
+    viennacl::ocl::kernel * kernel = &programs[fallback?0:1].program().get_kernel(fallback?kfallback:kopt);
     kernel->local_work_size(0, p_.local_size_0);
     kernel->global_work_size(0, p_.local_size_0*p_.num_groups);
     unsigned int current_arg = 0;
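Note: the rewrite above replaces the if/else with a single predicate so both kernel names are computed up front; the decision itself is unchanged. Sketched in Python (the helper flags stand in for has_strided_access / has_misaligned_offset):

    def needs_fallback(simd_width, size, strided_access, misaligned_offset):
        # The vectorized kernel (suffix '1') needs contiguous, aligned data
        # whose length is a multiple of the SIMD width; anything else takes
        # the scalar fallback kernel (suffix '0').
        return simd_width > 1 and (strided_access
                                   or size % simd_width > 0
                                   or misaligned_offset)

    # e.g. needs_fallback(4, 1000002, False, False) -> True  (1000002 % 4 != 0)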

View File

@@ -51,9 +51,10 @@ def do_tuning(args):
     for operation in ['vector-axpy', 'reduction', 'matrix-axpy', 'row-wise-reduction', 'matrix-product']:
         for datatype in [vcl.float32, vcl.float64]:
-            if not any(x in args.operations for x in [operation, operation + '-' + datatype.__name__]):
+            if operation not in args.operations and operation + '-' + datatype.__name__ not in args.operations:
                 continue
             ctx = cl.Context([device])
@@ -78,22 +79,25 @@ def do_tuning(args):
             def log_uniform_sample(a,b):
                 return np.exp(np.random.uniform(low=np.log(a), high=np.log(b), size=1)).astype(int)
-            def log_space_gen_product(a,b,N,dim):
+            def space_gen_product(a,b,N,dim,method):
                 N = int(N**(1.0/dim))
-                def log_space_gen(a,b):
+                def space_gen(a,b,method):
                     for i in range(N):
-                        v = int(np.exp(np.log(a) + (np.log(b) - np.log(a))*(i+1)/N))
+                        if method == 'linear':
+                            v = int(a + (b-a)*i/N)
+                        if method == 'log':
+                            v = int(np.exp(np.log(a) + (np.log(b) - np.log(a))*i/N))
                         yield (v//64 + 1)*64
-                return tuple(itertools.product(*[log_space_gen(a,b) for i in range(dim)]))
+                return tuple(itertools.product(*[space_gen(a,b,method) for i in range(dim)]))
             #Helper for tuning
-            def tune(execution_handler, a, b, dimsample, additional_parameters):
+            def tune(execution_handler, a, b, dimsample, layouts, sample_method_profiles, sample_method_dataset):
+                print args.build_model
                 print('-----')
-                print(' '.join(map(str, ("Now tuning:", datatype.__name__, '-', operation, '-'.join(additional_parameters), '[' + device.name, '(' + device.platform.name + ')]'))))
+                print(' '.join(map(str, ("Now tuning:", datatype.__name__, '-', operation, '-'.join(layouts), '[' + device.name, '(' + device.platform.name + ')]'))))
                 #Update JSON
-                full_operation = operation + ''.join(additional_parameters)
+                full_operation = operation + ''.join(layouts)
                 if full_operation not in json_out:
                     json_out[full_operation] = {}
                 json_out[full_operation][datatype.__name__] = {}
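For reference, the generalized sampler above draws N**(1/dim) points per axis, spaced either linearly or logarithmically between a and b, and rounds each value up to the next multiple of 64. A self-contained sketch of the same idea:

    import itertools
    import numpy as np

    def space_gen_product(a, b, N, dim, method):
        n = int(N ** (1.0 / dim))  # n points per axis, ~N points total
        def space_gen():
            for i in range(n):
                if method == 'linear':
                    v = int(a + (b - a) * i / n)
                else:  # 'log'
                    v = int(np.exp(np.log(a) + (np.log(b) - np.log(a)) * i / n))
                yield (v // 64 + 1) * 64  # round up to a multiple of 64
        return tuple(itertools.product(*[space_gen() for _ in range(dim)]))

    # space_gen_product(1e3, 2e7, 16, 2, 'log') -> 16 (M, N) pairs,
    # log-spaced along each axis, every entry a multiple of 64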
@@ -105,14 +109,14 @@ def do_tuning(args):
                 else:
                     def compute_perf(x, t):
                         return TYPES[operation]['perf-index']([datatype().itemsize, x, t])
-                profiles_generator = log_space_gen_product(a, b, args.sample_size, dimsample)
+                profiles_generator = space_gen_product(a, b, args.sample_size, dimsample, sample_method_profiles)
-                # profiles = dataset.sample_profiles(execution_handler, profiles_generator)
+                profiles = dataset.sample_profiles(execution_handler, profiles_generator)
                 if args.build_model:
-                    dataset_generator = log_space_gen_product(a, b, 1000, dimsample)
+                    dataset_generator = space_gen_product(a, b, 1000, dimsample, sample_method_dataset)
-                    # X, Y, profiles = dataset.sample_dataset(os.path.join(full_operation,datatype.__name__), profiles, execution_handler, dataset_generator)
+                    X, Y, profiles = dataset.sample_dataset(os.path.join(full_operation,datatype.__name__), profiles, execution_handler, dataset_generator)
-                    profiles = np.loadtxt('data/vector-axpy/float32/profiles.csv')
+                    # profiles = np.loadtxt('data/'+full_operation+'/'+datatype.__name__+'/profiles.csv')
-                    X = np.loadtxt('data/vector-axpy/float32/X.csv',ndmin=2)
+                    # X = np.loadtxt('data/'+full_operation+'/'+datatype.__name__+'/X.csv',ndmin=2)
-                    Y = np.loadtxt('data/vector-axpy/float32/Y.csv',ndmin=2)
+                    # Y = np.loadtxt('data/'+full_operation+'/'+datatype.__name__+'/Y.csv',ndmin=2)
                     clf = train_model(X, Y, profiles, TYPES[operation]['perf-measure'])
                     D['predictor'] = [{'children_left': e.tree_.children_left.tolist(),
                                        'children_right': e.tree_.children_right.tolist(),
@@ -120,7 +124,7 @@ def do_tuning(args):
                                        'feature': e.tree_.feature.astype('float64').tolist(),
                                        'value': e.tree_.value[:,:,0].astype('float64').tolist()} for e in clf.estimators_]
                 if args.viennacl_src_path:
-                    misc_tools.update_viennacl_headers(args.viennacl_src_path, device,datatype,operation,additional_parameters,profiles[0])
+                    misc_tools.update_viennacl_headers(args.viennacl_src_path, device,datatype,operation,layouts,profiles[0])
                 D['profiles'] = [map(int, x) for x in profiles]
@@ -130,7 +134,7 @@ def do_tuning(args):
                     x = vcl.Vector(sizes[0], context=ctx, dtype=datatype)
                     y = vcl.Vector(sizes[0], context=ctx, dtype=datatype)
                     return execute(device, vcl.Assign(y, x + y), (), sizes, fname, parameters)
-                tune(execution_handler, 1e4, 2e7, 1, ())
+                tune(execution_handler, 1e3, 2e7, 1, (),'log', 'log')
             #Reduction
             if operation=='reduction':
                 def execution_handler(sizes, fname=os.devnull, parameters=None):
@@ -138,14 +142,14 @@ def do_tuning(args):
                     y = vcl.Vector(sizes[0], context=ctx, dtype=datatype)
                     s = vcl.Scalar(0, context=ctx, dtype=datatype)
                     return execute(device, vcl.Assign(s, vcl.Dot(x,y)), (), sizes, fname, parameters)
-                tune(execution_handler, 1e4, 2e7, 1, ())
+                tune(execution_handler, 1e3, 2e7, 1, (),'log', 'log')
             #Matrix AXPY
             if operation=='matrix-axpy':
                 def execution_handler(sizes, fname=os.devnull, parameters=None):
                     A = vcl.Matrix(sizes, context=ctx, dtype=datatype, layout=vcl.COL_MAJOR)
                     C = vcl.Matrix(sizes, context=ctx, dtype=datatype, layout=vcl.COL_MAJOR)
                     return execute(device, vcl.Assign(C,A + C), (), sizes, fname, parameters)
-                tune(execution_handler, 100, 4000, 2, ())
+                tune(execution_handler, 100, 5000, 2, (),'log', 'log')
             #Row-wise reduction
             if operation=='row-wise-reduction':
                 for A_trans in args.gemv_layouts:
@@ -155,7 +159,7 @@ def do_tuning(args):
                         y = vcl.Vector(sizes[0], context=ctx, dtype=datatype)
                         LHS = A if A_trans=='N' else A.T
                         return execute(device, vcl.Assign(y, LHS*x), (), sizes, fname, parameters)
-                    tune(execution_handler, 100, 4000, 2, (A_trans,))
+                    tune(execution_handler, 100, 5000, 2, (A_trans,),'log', 'log')
             #Matrix Product
             if operation=='matrix-product':
                 for L in args.gemm_layouts:
@@ -170,7 +174,7 @@ def do_tuning(args):
                         beta = vcl.HostScalar(1.0, context=ctx, dtype = datatype)
                         C = vcl.Matrix((sizes[0], sizes[1]), context=ctx, dtype = datatype, layout=vcl.COL_MAJOR)
                         return execute(device, vcl.Assign(C,LHS*RHS*alpha + C*beta),(A_trans,B_trans), sizes, fname, parameters)
-                    tune(execution_handler, 100, 2000, 3,(A_trans,B_trans))
+                    tune(execution_handler, 100, 2000, 3,(A_trans,B_trans), 'linear')
     json.dump(json_out, open(args.json_file,'w'))
@@ -227,10 +231,12 @@ class ArgumentsHandler:
         args = parser.parse_args()
         self.__dict__ = args.__dict__.copy()
         #Retypes
+        self.operations = [self.operations] if not isinstance(self.operations, list) else self.operations
         self.device = devices[int(self.device)]
+        if not self.json_file:
+            self.json_file = misc_tools.sanitize_string(self.device.name) + '.json'
         self.gemm_layouts = self.gemm_layouts.split(',')
         self.gemv_layouts = self.gemv_layouts.split(',')
         if self.method == 'simple':

View File

@@ -132,7 +132,7 @@ class GeneticOperators(object):
                 tt = misc_tools.benchmark(template, self.statement, self.device)
                 self.out.write(','.join([str(tt)]+map(str,map(int,parameters)))+'\n')
                 self.cache[tuple(individual)] = tt
-            except:
+            except ValueError:
                 self.cache[tuple(individual)] = 10
         return self.cache[tuple(individual)],
@@ -161,9 +161,14 @@ class GeneticOperators(object):
         for _ in xrange(mu):
             op_choice = random.random()
             if op_choice < cxpb: # Apply crossover
-                ind1, ind2 = map(self.toolbox.clone, random.sample(population, 2))
-                ind1, ind2 = self.toolbox.mate(ind1, ind2)
-                del ind1.fitness.values
+                while True:
+                    ind1, ind2 = map(self.toolbox.clone, random.sample(population, 2))
+                    ind1, ind2 = self.toolbox.mate(ind1, ind2)
+                    del ind1.fitness.values
+                    parameters = self.decode(ind1)
+                    template = self.build_template(self.TemplateType.Parameters(*parameters))
+                    if not misc_tools.skip(template, self.statement, self.device):
+                        break
                 offspring.append(ind1)
             elif op_choice < cxpb + mutpb: # Apply mutation
                 ind = self.toolbox.clone(random.choice(population))
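Note: crossover now resamples parents until mating yields a profile that misc_tools.skip would accept for this statement and device, rather than admitting invalid offspring into the population. The retry pattern, sketched with stand-in predicates (the bounded retry count is a safety addition, not in the commit):

    import random

    def mate_valid(population, clone, mate, decode, is_valid, max_tries=1000):
        # is_valid(parameters) stands in for
        # `not misc_tools.skip(build_template(parameters), statement, device)`.
        for _ in range(max_tries):
            ind1, ind2 = map(clone, random.sample(population, 2))
            ind1, ind2 = mate(ind1, ind2)
            if is_valid(decode(ind1)):
                return ind1
        raise RuntimeError('no valid offspring after %d attempts' % max_tries)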

View File

@@ -7,6 +7,7 @@ import sys
 import pyopencl as cl
 import pyviennacl as vcl
+import pyatidlas as atd
 import numpy as np
 class PhysicalLimitsNV:
@@ -214,13 +215,16 @@ def benchmark(template, statement, device):
         if occupancy_record.occupancy < 15 :
             raise ValueError("Template has too low occupancy")
         else:
-            template.execute(statement, True)
+            vcl_statements = statements.vcl_tuple
+            vcl_context = statement.result.context.vcl_sub_context
+            model = atd._atidlas.model(template._vcl_template, vcl_context, vcl_context.current_device)
+            model.execute(vcl_statements, False, True)
             statement.result.context.finish_all_queues()
             current_time = 0
             timings = []
             while current_time < 1e-1:
                 time_before = time.time()
-                template.execute(statement,False)
+                model.execute(vcl_statements, False, False)
                 statement.result.context.finish_all_queues()
                 timings.append(time.time() - time_before)
                 current_time = current_time + timings[-1]
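This hunk is the overhead fix named in the commit message: template.execute re-resolved its OpenCL program on every call, so that per-call cost leaked into the measurements. The benchmark now builds an atd._atidlas.model once, pays the compilation cost in a single warm-up call (force_recompilation=True), and only times steady-state executions. The timing discipline, sketched (returning the median is an assumption; the code after this hunk is not shown):

    import time

    def time_steady_state(warmup, run, min_total=1e-1):
        # warmup() performs the one-off compilation, e.g.
        # model.execute(vcl_statements, False, True); run() is one timed
        # call, e.g. model.execute(vcl_statements, False, False) followed
        # by a queue flush so the kernel actually finishes.
        warmup()
        timings, total = [], 0.0
        while total < min_total:  # accumulate ~0.1 s of samples
            t0 = time.time()
            run()
            timings.append(time.time() - t0)
            total += timings[-1]
        return sorted(timings)[len(timings) // 2]  # median sample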

View File

@@ -1,5 +1,6 @@
 from sklearn import tree
 from sklearn import ensemble
+from sklearn.grid_search import GridSearchCV
 import numpy as np
 def gmean(a, axis=0, dtype=None):
@@ -20,7 +21,6 @@ def nrmse(y_ground, y):
     return rmsd/(np.max(y_ground) - np.min(y_ground))
 def train_model(X, Y, profiles, metric):
-    #Shuffle
     p = np.random.permutation(X.shape[0])
     X = X[p,:]
     Y = Y[p,:]
@@ -28,18 +28,34 @@ def train_model(X, Y, profiles, metric):
     Ymax = np.max(Y)
     Y = Y/Ymax
     #Train the model
-    cut = int(0.9*X.shape[0])
+    cut = int(0.95*X.shape[0])
-    nrmses = {}
-    for depth in range(1,10):
-        clf = ensemble.RandomForestRegressor(5, max_depth=4).fit(X[:cut,:], Y[:cut,:])
-        t = np.argmin(clf.predict(X[cut:,:]), axis = 1)
-        y = np.array([Y[cut+i,t[i]] for i in range(t.size)])
-        y_ground = np.min(Y[cut:,:], axis=1)
-        # for i in range(t.size):
-        #     print X[cut+i,:], y[i], y_ground[i]
-        nrmses[clf] = nrmse(y_ground, y)
-        print depth, nrmses[clf]
+    XTr, YTr = X[:cut,:], Y[:cut,:]
+    XCv, YCv = X[cut:,:], Y[cut:,:]
+    nrmses = {}
+    for N in range(1,10):
+        for depth in range(1,5):
+            clf = ensemble.RandomForestRegressor(N, max_depth=depth).fit(XTr, YTr)
+            t = np.argmin(clf.predict(XCv), axis = 1)
+            y = np.array([YCv[i,t[i]] for i in range(t.size)])
+            nrmses[clf] = nrmse(np.min(YCv[:,:], axis=1), y)
     clf = min(nrmses, key=nrmses.get)
+    t = np.argmin(clf.predict(XCv), axis = 1)
+    s = np.array([y[0]/y[k] for y,k in zip(YCv, t)])
+    tt = np.argmin(YCv, axis = 1)
+    ss = np.array([y[0]/y[k] for y,k in zip(YCv, tt)])
+    p5 = lambda a: np.percentile(a, 5)
+    p25 = lambda a: np.percentile(a, 25)
+    p50 = lambda a: np.percentile(a, 50)
+    p75 = lambda a: np.percentile(a, 75)
+    p95 = lambda a: np.percentile(a, 95)
+    print("Percentile :\t 5 \t 25 \t 50 \t 75 \t 95")
+    print("Testing speedup:\t %.2f\t %.2f\t %.2f\t %.2f\t %.3f"%(p5(s), p25(s), p50(s), p75(s), p95(s)))
+    print("Optimal speedup:\t %.2f\t %.2f\t %.2f\t %.2f\t %.3f"%(p5(ss), p25(ss), p50(ss), p75(ss), p95(ss)))
+    print clf
     return clf
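The retrained selector above holds out the last 5% of the shuffled samples, grid-searches forest size and depth by NRMSE on that split, and prints speedup percentiles: s measures the chosen profile against profile 0 (presumably the default choice), ss does the same for the true per-sample optimum, so the two rows bound how much of the attainable speedup the forest recovers. A condensed sketch of that selection loop:

    import numpy as np
    from sklearn import ensemble

    def select_forest(X, Y, holdout=0.05):
        cut = int((1 - holdout) * X.shape[0])
        XTr, YTr, XCv, YCv = X[:cut], Y[:cut], X[cut:], Y[cut:]
        def nrmse(y_ground, y):
            rmsd = np.sqrt(np.mean((y - y_ground) ** 2))
            return rmsd / (np.max(y_ground) - np.min(y_ground))
        scores = {}
        for n in range(1, 10):
            for depth in range(1, 5):
                clf = ensemble.RandomForestRegressor(n_estimators=n,
                                                     max_depth=depth).fit(XTr, YTr)
                t = np.argmin(clf.predict(XCv), axis=1)  # predicted-best profile
                scores[clf] = nrmse(np.min(YCv, axis=1), YCv[np.arange(len(t)), t])
        clf = min(scores, key=scores.get)  # lowest cross-validation NRMSE
        t = np.argmin(clf.predict(XCv), axis=1)
        s = YCv[:, 0] / YCv[np.arange(len(t)), t]  # speedup over profile 0
        for q in (5, 25, 50, 75, 95):
            print('p%d speedup: %.2f' % (q, np.percentile(s, q)))
        return clf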

View File

@@ -2,7 +2,7 @@ set(SETUP_PY_IN "${CMAKE_CURRENT_SOURCE_DIR}/setup.py")
 set(SETUP_PY "${CMAKE_CURRENT_BINARY_DIR}/setup.py")
 set(OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/build/timestamp")
 file(GLOB DEPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}/pyatidlas/*.py ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp")
-list(APPEND DEPS "${CMAKE_CURRENT_SOURCE_DIR}/setup.py")
+list(APPEND DEPS "${CMAKE_CURRENT_SOURCE_DIR}/setup.py" "${CMAKE_CURRENT_SOURCE_DIR}/src/_atidlas.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/pyatidlas/pycore.py")
 configure_file(${SETUP_PY_IN} ${SETUP_PY})
 add_custom_command(OUTPUT ${OUTPUT}

View File

@@ -44,10 +44,10 @@ def main():
     DEFINES = [('VIENNACL_WITH_OPENCL',None), ('VIENNACL_WITH_OPENMP', None),
                ('boost','pyviennaclboost')]
-    INCLUDE_DIRS = ['/home/philippe/Development/pyviennacl-dev/external/boost-python-ublas-subset/boost_subset/',
+    INCLUDE_DIRS = ['${CMAKE_CURRENT_SOURCE_DIR}/external/pyviennacl-dev/external/boost-python-ublas-subset/boost_subset/',
                     '${PROJECT_SOURCE_DIR}',
-                    '/home/philippe/Development/pyviennacl-dev/external/viennacl-dev']
+                    '${CMAKE_CURRENT_SOURCE_DIR}/external/pyviennacl-dev/external/viennacl-dev']
-    LIBRARY_DIRS = ['/home/philippe/Development/pyviennacl-dev/build/lib.linux-x86_64-2.7/pyviennacl/']
+    LIBRARY_DIRS = ['${CMAKE_CURRENT_SOURCE_DIR}/external/pyviennacl-dev/build/lib.linux-x86_64-2.7/pyviennacl/']
     setup(
         name="pyatidlas",

View File

@@ -10,18 +10,22 @@
 #include "atidlas/templates/row_wise_reduction.hpp"
 #include "atidlas/templates/matrix_product.hpp"
-#include "atidlas/execute.hpp"
+#include "atidlas/model/model.hpp"
 #define ENUM_VALUE(NS, V) .value( #V, NS :: V )
 namespace bp = boost::python;
 namespace vcl = viennacl;
+namespace atd = atidlas;
 void export_atidlas()
 {
-  bp::def("execute", &atidlas::execute);
+  bp::class_<atidlas::model>("model", bp::init<atd::template_base const &, vcl::ocl::context &, vcl::ocl::device const & >())
+      .def("execute", &atd::model::execute)
+      ;
   bp::enum_<atidlas::fetching_policy_type>
       ("fetching_policy_type")
       ENUM_VALUE(atidlas, FETCH_FROM_LOCAL)
@@ -107,4 +111,4 @@ BOOST_PYTHON_MODULE(_atidlas)
   bp::object package = bp::scope();
   package.attr("__path__") = "_atidlas";
   export_atidlas();
 }