Tuner: Moved tuning logic into the python wrapper - draft of Android tuning app using kivy
This commit is contained in:
28
python/external/sklearn/_tree.pyx
vendored
28
python/external/sklearn/_tree.pyx
vendored
@@ -23,11 +23,8 @@ from cpython cimport Py_INCREF, PyObject
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
np.import_array()
|
||||
|
||||
from _tree cimport StackRecord
|
||||
|
||||
from scipy.sparse import issparse, csc_matrix, csr_matrix
|
||||
|
||||
cdef extern from "numpy/arrayobject.h":
|
||||
object PyArray_NewFromDescr(object subtype, np.dtype descr,
|
||||
int nd, np.npy_intp* dims,
|
||||
@@ -2063,9 +2060,6 @@ cdef class BaseSparseSplitter(Splitter):
|
||||
# Call parent init
|
||||
Splitter.init(self, X, y, sample_weight)
|
||||
|
||||
if not isinstance(X, csc_matrix):
|
||||
raise ValueError("X should be in csc format")
|
||||
|
||||
cdef SIZE_t* samples = self.samples
|
||||
cdef SIZE_t n_samples = self.n_samples
|
||||
|
||||
@@ -2790,18 +2784,7 @@ cdef class TreeBuilder:
|
||||
cdef inline _check_input(self, object X, np.ndarray y,
|
||||
np.ndarray sample_weight):
|
||||
"""Check input dtype, layout and format"""
|
||||
if issparse(X):
|
||||
X = X.tocsc()
|
||||
X.sort_indices()
|
||||
|
||||
if X.data.dtype != DTYPE:
|
||||
X.data = np.ascontiguousarray(X.data, dtype=DTYPE)
|
||||
|
||||
if X.indices.dtype != np.int32 or X.indptr.dtype != np.int32:
|
||||
raise ValueError("No support for np.int64 index based "
|
||||
"sparse matrices")
|
||||
|
||||
elif X.dtype != DTYPE:
|
||||
if X.dtype != DTYPE:
|
||||
# since we have to copy we will make it fortran for efficiency
|
||||
X = np.asfortranarray(X, dtype=DTYPE)
|
||||
|
||||
@@ -3430,10 +3413,7 @@ cdef class Tree:
|
||||
|
||||
cpdef np.ndarray apply(self, object X):
|
||||
"""Finds the terminal region (=leaf node) for each sample in X."""
|
||||
if issparse(X):
|
||||
return self._apply_sparse_csr(X)
|
||||
else:
|
||||
return self._apply_dense(X)
|
||||
return self._apply_dense(X)
|
||||
|
||||
|
||||
cdef inline np.ndarray _apply_dense(self, object X):
|
||||
@@ -3482,10 +3462,6 @@ cdef class Tree:
|
||||
"""Finds the terminal region (=leaf node) for each sample in sparse X.
|
||||
|
||||
"""
|
||||
# Check input
|
||||
if not isinstance(X, csr_matrix):
|
||||
raise ValueError("X should be in csr_matrix format, got %s"
|
||||
% type(X))
|
||||
|
||||
if X.dtype != DTYPE:
|
||||
raise ValueError("X.dtype should be np.float32, got %s" % X.dtype)
|
||||
|
Reference in New Issue
Block a user