2014-10-15 04:24:25 -04:00
|
|
|
from . import _atidlas as _atd
|
2014-10-14 23:49:18 -04:00
|
|
|
|
|
|
|
|
2014-10-15 04:24:25 -04:00
|
|
|
# Public alias for the backend extension's ``fetching_policy_type`` so callers
# can use it without touching the private ``_atd`` module directly.
FetchingPolicy = _atd.fetching_policy_type
|
2014-10-14 23:49:18 -04:00
|
|
|
|
|
|
|
class TemplateBase(object):
    """Common Python wrapper around a backend template object.

    This base class does not create a backend object itself. Every concrete
    subclass in this module assigns ``self._vcl_template`` in its own
    ``__init__`` after calling this initializer, and all methods below simply
    delegate to that object. Using ``TemplateBase`` directly therefore fails
    with ``AttributeError`` as soon as a delegating method is called.
    """

    # Backend parameter type associated with the generic template.
    Parameters = _atd.template_base.parameters_type

    def __init__(self):
        # Intentionally empty: subclasses are responsible for setting
        # ``self._vcl_template``.
        pass

    @property
    def parameters(self):
        """Return whatever the backend template's ``parameters()`` reports."""
        return self._vcl_template.parameters()

    def lmem_usage(self, statements):
        """Return the backend's ``lmem_usage`` value for *statements*.

        *statements* must expose a ``vcl_tuple`` attribute; that attribute is
        what is forwarded to the backend. (Presumably this is local-memory
        usage -- NOTE(review): confirm units against the backend.)
        """
        return self._vcl_template.lmem_usage(statements.vcl_tuple)

    def registers_usage(self, statements):
        """Return the backend's ``registers_usage`` value for *statements*.

        *statements* must expose a ``vcl_tuple`` attribute, which is forwarded
        unchanged to the backend.
        """
        return self._vcl_template.registers_usage(statements.vcl_tuple)

    def check(self, statement):
        """Return the result of ``check_template`` for *statement*.

        The backend statement (``statement.vcl_statement``) is checked against
        this template in the sub-context found at
        ``statement.result.context.vcl_sub_context``. The meaning of the
        returned value is defined by the backend.
        """
        vcl_statement = statement.vcl_statement
        vcl_context = statement.result.context.vcl_sub_context
        return vcl_statement.check_template(self._vcl_template, vcl_context)

    def execute(self, statement, force_compilation=False):
        """Execute *statement* with this template and return its result.

        The backend's ``execute_template`` is invoked on
        ``statement.vcl_statement`` with this template and the sub-context at
        ``statement.result.context.vcl_sub_context``. *force_compilation* is
        forwarded unchanged to the backend. The return value is
        ``statement.result`` (the same object the statement already held).
        """
        vcl_statement = statement.vcl_statement
        vcl_context = statement.result.context.vcl_sub_context
        vcl_statement.execute_template(self._vcl_template, vcl_context,
                                       force_compilation)
        return statement.result
|
|
|
|
|
|
|
|
|
|
|
|
class VectorAxpyTemplate(TemplateBase):
    """Vector-axpy template backed by ``_atd.vector_axpy_template``."""

    # Backend parameter type for this template (presumably the type expected
    # for the ``parameters`` constructor argument -- confirm).
    Parameters = _atd.vector_axpy_template.parameters_type

    def __init__(self, parameters):
        """Build the backend vector-axpy template from *parameters*."""
        super(VectorAxpyTemplate, self).__init__()
        make_backend = _atd.vector_axpy_template
        self._vcl_template = make_backend(parameters)
|
2014-10-14 23:49:18 -04:00
|
|
|
|
|
|
|
|
|
|
|
class MatrixAxpyTemplate(TemplateBase):
    """Matrix-axpy template backed by ``_atd.matrix_axpy_template``."""

    # Backend parameter type for this template (presumably the type expected
    # for the ``parameters`` constructor argument -- confirm).
    Parameters = _atd.matrix_axpy_template.parameters_type

    def __init__(self, parameters):
        """Build the backend matrix-axpy template from *parameters*."""
        super(MatrixAxpyTemplate, self).__init__()
        backend_factory = _atd.matrix_axpy_template
        self._vcl_template = backend_factory(parameters)
|
2014-10-14 23:49:18 -04:00
|
|
|
|
|
|
|
|
|
|
|
class ReductionTemplate(TemplateBase):
    """Reduction template backed by ``_atd.reduction_template``."""

    # Backend parameter type for this template (presumably the type expected
    # for the ``parameters`` constructor argument -- confirm).
    Parameters = _atd.reduction_template.parameters_type

    def __init__(self, parameters):
        """Build the backend reduction template from *parameters*."""
        super(ReductionTemplate, self).__init__()
        backend = _atd.reduction_template(parameters)
        self._vcl_template = backend
|
2014-10-14 23:49:18 -04:00
|
|
|
|
|
|
|
class RowWiseReductionTemplate(TemplateBase):
    """Row-wise reduction template backed by ``_atd.row_wise_reduction_template``."""

    # Backend parameter type for this template (presumably the type expected
    # for the ``parameters`` constructor argument -- confirm).
    Parameters = _atd.row_wise_reduction_template.parameters_type

    def __init__(self, parameters):
        """Build the backend row-wise reduction template from *parameters*."""
        super(RowWiseReductionTemplate, self).__init__()
        backend_template = _atd.row_wise_reduction_template(parameters)
        self._vcl_template = backend_template
|
2014-10-14 23:49:18 -04:00
|
|
|
|
|
|
|
|
|
|
|
class MatrixProductTemplate(TemplateBase):
    """Matrix-product template backed by ``_atd.matrix_product_template``.

    Besides the backend parameters, the constructor takes two transposition
    flags. They are forwarded to the backend and also kept on the Python
    side so they can be read back through the ``A_trans`` and ``B_trans``
    properties.
    """

    # Backend parameter type for this template (presumably the type expected
    # for the ``parameters`` constructor argument -- confirm).
    Parameters = _atd.matrix_product_template.parameters_type

    def __init__(self, parameters, A_trans, B_trans):
        """Build the backend matrix-product template.

        *parameters*, *A_trans* and *B_trans* are passed unchanged to
        ``_atd.matrix_product_template``. The two flags are also stored
        on the instance as given.
        """
        super(MatrixProductTemplate, self).__init__()
        # Keep the flags as given so the properties below report exactly what
        # the caller passed in.
        self._A_trans = A_trans
        self._B_trans = B_trans
        self._vcl_template = _atd.matrix_product_template(parameters, A_trans, B_trans)

    @property
    def A_trans(self):
        """Return the ``A_trans`` value given to the constructor."""
        return self._A_trans

    @property
    def B_trans(self):
        """Return the ``B_trans`` value given to the constructor."""
        return self._B_trans
|