Python: fixed compilation issues

This commit is contained in:
Philippe Tillet
2016-08-13 09:41:04 -07:00
parent fd5c6d3915
commit 5178ba06f9
3 changed files with 20 additions and 20 deletions

View File

@@ -49,8 +49,8 @@ void export_templates()
bp::scope template_scope = templates_module; bp::scope template_scope = templates_module;
bp::enum_<tpt::fetching_policy_type> bp::enum_<tpt::fetch_type>
("fetching_policy_type") ("fetch_type")
.value("FETCH_FROM_LOCAL", tpt::FETCH_FROM_LOCAL) .value("FETCH_FROM_LOCAL", tpt::FETCH_FROM_LOCAL)
.value("FETCH_FROM_GLOBAL_STRIDED", tpt::FETCH_FROM_GLOBAL_STRIDED) .value("FETCH_FROM_GLOBAL_STRIDED", tpt::FETCH_FROM_GLOBAL_STRIDED)
.value("FETCH_FROM_GLOBAL_CONTIGUOUS", tpt::FETCH_FROM_GLOBAL_CONTIGUOUS); .value("FETCH_FROM_GLOBAL_CONTIGUOUS", tpt::FETCH_FROM_GLOBAL_CONTIGUOUS);
@@ -69,25 +69,25 @@ void export_templates()
} }
#define WRAP_BASE(name) bp::class_<tpt::base_impl<tpt::name, tpt::name::parameters_type>, bp::bases<tpt::base>, boost::noncopyable>(#name, bp::no_init)\ #define WRAP_BASE(name) bp::class_<tpt::base_impl<tpt::name, tpt::name::parameters_type>, bp::bases<tpt::base>, boost::noncopyable>(#name, bp::no_init)\
.add_property("local_size_0", &tpt::base_impl<tpt::name, tpt::name::parameters_type>::local_size_0)\ .add_property("ls0", &tpt::base_impl<tpt::name, tpt::name::parameters_type>::ls0)\
.add_property("local_size_1", &tpt::base_impl<tpt::name, tpt::name::parameters_type>::local_size_1); .add_property("ls1", &tpt::base_impl<tpt::name, tpt::name::parameters_type>::ls1);
#define WRAP_TEMPLATE(name, basename, ...) bp::class_<tpt::name, bp::bases<tpt::base_impl<tpt::basename, tpt::basename::parameters_type> > >(#name, bp::init<__VA_ARGS__>())\ #define WRAP_TEMPLATE(name, basename, ...) bp::class_<tpt::name, bp::bases<tpt::base_impl<tpt::basename, tpt::basename::parameters_type> > >(#name, bp::init<__VA_ARGS__>())\
; ;
#define WRAP_SINGLE_TEMPLATE(name, ...) WRAP_BASE(name) WRAP_TEMPLATE(name, name, __VA_ARGS__) #define WRAP_SINGLE_TEMPLATE(name, ...) WRAP_BASE(name) WRAP_TEMPLATE(name, name, __VA_ARGS__)
//Vector AXPY //Vector AXPY
WRAP_SINGLE_TEMPLATE(elementwise_1d, uint, uint, uint, tpt::fetching_policy_type) WRAP_SINGLE_TEMPLATE(elementwise_1d, uint, uint, uint, tpt::fetch_type)
WRAP_SINGLE_TEMPLATE(elementwise_2d, uint, uint, uint, uint, uint, tpt::fetching_policy_type) WRAP_SINGLE_TEMPLATE(elementwise_2d, uint, uint, uint, uint, uint, tpt::fetch_type)
WRAP_SINGLE_TEMPLATE(reduce_1d, uint, uint, uint, tpt::fetching_policy_type) WRAP_SINGLE_TEMPLATE(reduce_1d, uint, uint, uint, tpt::fetch_type)
WRAP_BASE(reduce_2d) WRAP_BASE(reduce_2d)
WRAP_TEMPLATE(reduce_2d_rows, reduce_2d, uint, uint, uint, uint, uint, tpt::fetching_policy_type) WRAP_TEMPLATE(reduce_2d_rows, reduce_2d, uint, uint, uint, uint, uint, tpt::fetch_type)
WRAP_TEMPLATE(reduce_2d_cols, reduce_2d, uint, uint, uint, uint, uint, tpt::fetching_policy_type) WRAP_TEMPLATE(reduce_2d_cols, reduce_2d, uint, uint, uint, uint, uint, tpt::fetch_type)
WRAP_BASE(matrix_product) WRAP_BASE(matrix_product)
WRAP_TEMPLATE(matrix_product_nn, matrix_product, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetching_policy_type, tpt::fetching_policy_type, uint, uint) WRAP_TEMPLATE(matrix_product_nn, matrix_product, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
WRAP_TEMPLATE(matrix_product_tn, matrix_product, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetching_policy_type, tpt::fetching_policy_type, uint, uint) WRAP_TEMPLATE(matrix_product_tn, matrix_product, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
WRAP_TEMPLATE(matrix_product_nt, matrix_product, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetching_policy_type, tpt::fetching_policy_type, uint, uint) WRAP_TEMPLATE(matrix_product_nt, matrix_product, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
WRAP_TEMPLATE(matrix_product_tt, matrix_product, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetching_policy_type, tpt::fetching_policy_type, uint, uint) WRAP_TEMPLATE(matrix_product_tt, matrix_product, uint, uint, uint, uint, uint, uint, uint, uint, tpt::fetch_type, tpt::fetch_type, uint, uint)
} }

View File

@@ -34,10 +34,10 @@ import tools
from tools import profile_execution_failure from tools import profile_execution_failure
from time import sleep from time import sleep
fetch_types = [sc.templates.fetching_policy_type.FETCH_FROM_GLOBAL_CONTIGUOUS, fetch_types = [sc.templates.fetch_type.FETCH_FROM_GLOBAL_CONTIGUOUS,
sc.templates.fetching_policy_type.FETCH_FROM_GLOBAL_STRIDED, sc.templates.fetch_type.FETCH_FROM_GLOBAL_STRIDED,
sc.templates.fetching_policy_type.FETCH_FROM_LOCAL, sc.templates.fetch_type.FETCH_FROM_LOCAL,
sc.templates.fetching_policy_type.FETCH_FROM_LOCAL] sc.templates.fetch_type.FETCH_FROM_LOCAL]
def exhaustive(template, sizes, context): def exhaustive(template, sizes, context):
tree, _ = tools.tree_of(template, sizes, context) tree, _ = tools.tree_of(template, sizes, context)

View File

@@ -152,11 +152,11 @@ class Tuner:
with open(os.path.join(savepath, 'profiles.csv')) as f: with open(os.path.join(savepath, 'profiles.csv')) as f:
def mmap(x): def mmap(x):
if x=='FETCH_FROM_LOCAL': if x=='FETCH_FROM_LOCAL':
return sc.templates.fetching_policy_type.FETCH_FROM_LOCAL return sc.templates.fetch_type.FETCH_FROM_LOCAL
if x=='FETCH_FROM_GLOBAL_CONTIGUOUS': if x=='FETCH_FROM_GLOBAL_CONTIGUOUS':
return sc.templates.fetching_policy_type.FETCH_FROM_GLOBAL_CONTIGUOUS return sc.templates.fetch_type.FETCH_FROM_GLOBAL_CONTIGUOUS
if x=='FETCH_FROM_GLOBAL_STRIDED': if x=='FETCH_FROM_GLOBAL_STRIDED':
return sc.templates.fetching_policy_type.FETCH_FROM_GLOBAL_STRIDED return sc.templates.fetch_type.FETCH_FROM_GLOBAL_STRIDED
return int(x) return int(x)
profiles = [map(mmap,row) for v in row for row in csv.reader(f, delimiter=',')] profiles = [map(mmap,row) for v in row for row in csv.reader(f, delimiter=',')]
except: except: