low level representation of array

This commit is contained in:
Philippe Tillet
2015-01-18 14:52:45 -05:00
parent 16648f18e0
commit edaa821d93
17 changed files with 243 additions and 194 deletions

View File

@@ -161,9 +161,9 @@ namespace atidlas
//Init
expression_type current_type;
if(root_save.lhs.array->nshape()==0)
if(root_save.lhs.array.shape1==1 && root_save.lhs.array.shape2==1)
current_type = SCALAR_AXPY_TYPE;
else if(root_save.lhs.array->nshape()==1)
else if(root_save.lhs.array.shape1==1 || root_save.lhs.array.shape2==1)
current_type=VECTOR_AXPY_TYPE;
else
current_type=MATRIX_AXPY_TYPE;
@@ -186,15 +186,15 @@ namespace atidlas
case SCALAR_AXPY_TYPE:
case REDUCTION_TYPE: tmp = tools::shared_ptr<obj_base>(new array(1, dtype, context)); break;
case VECTOR_AXPY_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array->shape()._1, dtype, context)); break;
case ROW_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array->shape()._1, dtype, context)); break;
case COL_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array->shape()._2, dtype, context)); break;
case VECTOR_AXPY_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array.shape1, dtype, context)); break;
case ROW_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array.shape1, dtype, context)); break;
case COL_WISE_REDUCTION_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array.shape2, dtype, context)); break;
case MATRIX_AXPY_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array->shape()._1, lmost.lhs.array->shape()._2, dtype, context)); break;
case MATRIX_PRODUCT_NN_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array->shape()._1, node.rhs.array->shape()._2, dtype, context)); break;
case MATRIX_PRODUCT_NT_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array->shape()._1, node.rhs.array->shape()._1, dtype, context)); break;
case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array->shape()._2, node.rhs.array->shape()._2, dtype, context)); break;
case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array->shape()._2, node.rhs.array->shape()._1, dtype, context)); break;
case MATRIX_AXPY_TYPE: tmp = tools::shared_ptr<obj_base>(new array(lmost.lhs.array.shape1, lmost.lhs.array.shape2, dtype, context)); break;
case MATRIX_PRODUCT_NN_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array.shape1, node.rhs.array.shape2, dtype, context)); break;
case MATRIX_PRODUCT_NT_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array.shape1, node.rhs.array.shape1, dtype, context)); break;
case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array.shape2, node.rhs.array.shape2, dtype, context)); break;
case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array.shape2, node.rhs.array.shape1, dtype, context)); break;
default: throw "This shouldn't happen. Ever.";
}
@@ -213,7 +213,7 @@ namespace atidlas
rit->second->dtype = dtype;
rit->second->type_family = ARRAY_TYPE_FAMILY;
rit->second->subtype = DENSE_ARRAY_TYPE;
rit->second->array = (array*)tmp.get();
fill((array&)*tmp, rit->second->array);
}
/*-----Compute final expression-----*/