low level representation of array
This commit is contained in:
@@ -161,9 +161,9 @@ namespace atidlas
|
||||
|
||||
//Init
|
||||
expression_type current_type;
|
||||
if(root_save.lhs.array->nshape()==0)
|
||||
if(root_save.lhs.array.shape1==1 && root_save.lhs.array.shape2==1)
|
||||
current_type = SCALAR_AXPY_TYPE;
|
||||
else if(root_save.lhs.array->nshape()==1)
|
||||
else if(root_save.lhs.array.shape1==1 || root_save.lhs.array.shape2==1)
|
||||
current_type=VECTOR_AXPY_TYPE;
|
||||
else
|
||||
current_type=MATRIX_AXPY_TYPE;
|
||||
@@ -186,15 +186,15 @@ namespace atidlas
|
||||
case SCALAR_AXPY_TYPE:
|
||||
case REDUCTION_TYPE: tmp = tools::shared_ptr<obj_base>(new array(1, dtype, context)); break;
|
||||
|
||||
case VECTOR_AXPY_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array->shape()._1, dtype, context)); break;
|
||||
case ROW_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array->shape()._1, dtype, context)); break;
|
||||
case COL_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array->shape()._2, dtype, context)); break;
|
||||
case VECTOR_AXPY_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array.shape1, dtype, context)); break;
|
||||
case ROW_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array.shape1, dtype, context)); break;
|
||||
case COL_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array.shape2, dtype, context)); break;
|
||||
|
||||
case MATRIX_AXPY_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array->shape()._1, lmost.lhs.array->shape()._2, dtype, context)); break;
|
||||
case MATRIX_PRODUCT_NN_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array->shape()._1, node.rhs.array->shape()._2, dtype, context)); break;
|
||||
case MATRIX_PRODUCT_NT_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array->shape()._1, node.rhs.array->shape()._1, dtype, context)); break;
|
||||
case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array->shape()._2, node.rhs.array->shape()._2, dtype, context)); break;
|
||||
case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array->shape()._2, node.rhs.array->shape()._1, dtype, context)); break;
|
||||
case MATRIX_AXPY_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array.shape1, lmost.lhs.array.shape2, dtype, context)); break;
|
||||
case MATRIX_PRODUCT_NN_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array.shape1, node.rhs.array.shape2, dtype, context)); break;
|
||||
case MATRIX_PRODUCT_NT_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array.shape1, node.rhs.array.shape1, dtype, context)); break;
|
||||
case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array.shape2, node.rhs.array.shape2, dtype, context)); break;
|
||||
case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array.shape2, node.rhs.array.shape1, dtype, context)); break;
|
||||
|
||||
default: throw "This shouldn't happen. Ever.";
|
||||
}
|
||||
@@ -213,7 +213,7 @@ namespace atidlas
|
||||
rit->second->dtype = dtype;
|
||||
rit->second->type_family = ARRAY_TYPE_FAMILY;
|
||||
rit->second->subtype = DENSE_ARRAY_TYPE;
|
||||
rit->second->array = (array*)tmp.get();
|
||||
fill((array&)*tmp, rit->second->array);
|
||||
}
|
||||
|
||||
/*-----Compute final expression-----*/
|
||||
|
Reference in New Issue
Block a user