mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-05-10 06:12:20 +00:00
feat: Add new components TextEmbeddingRetriever and MultiRetriever (#10872)
This commit is contained in:
committed by
GitHub
parent
02e0b42132
commit
2a4f104edf
@@ -11,9 +11,11 @@ _import_structure = {
|
||||
"auto_merging_retriever": ["AutoMergingRetriever"],
|
||||
"filter_retriever": ["FilterRetriever"],
|
||||
"in_memory": ["InMemoryBM25Retriever", "InMemoryEmbeddingRetriever"],
|
||||
"multi_retriever": ["MultiRetriever"],
|
||||
"multi_query_embedding_retriever": ["MultiQueryEmbeddingRetriever"],
|
||||
"multi_query_text_retriever": ["MultiQueryTextRetriever"],
|
||||
"sentence_window_retriever": ["SentenceWindowRetriever"],
|
||||
"text_embedding_retriever": ["TextEmbeddingRetriever"],
|
||||
}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -23,7 +25,9 @@ if TYPE_CHECKING:
|
||||
from .in_memory import InMemoryEmbeddingRetriever as InMemoryEmbeddingRetriever
|
||||
from .multi_query_embedding_retriever import MultiQueryEmbeddingRetriever as MultiQueryEmbeddingRetriever
|
||||
from .multi_query_text_retriever import MultiQueryTextRetriever as MultiQueryTextRetriever
|
||||
from .multi_retriever import MultiRetriever as MultiRetriever
|
||||
from .sentence_window_retriever import SentenceWindowRetriever as SentenceWindowRetriever
|
||||
from .text_embedding_retriever import TextEmbeddingRetriever as TextEmbeddingRetriever
|
||||
|
||||
else:
|
||||
sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
|
||||
|
||||
@@ -0,0 +1,284 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Any
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict
|
||||
from haystack.components.retrievers.types.protocol import TextRetriever
|
||||
from haystack.core.serialization import component_from_dict, component_to_dict, import_class_by_name
|
||||
from haystack.dataclasses import Document
|
||||
from haystack.utils.experimental import _experimental
|
||||
from haystack.utils.misc import _deduplicate_documents
|
||||
|
||||
|
||||
@_experimental
@component
class MultiRetriever:
    """
    Runs several text retrievers in parallel and merges their results into one deduplicated list.

    > **Note:** This component is experimental and may change or be removed in future releases without prior
    deprecation notice.

    Every retriever must satisfy the `TextRetriever` protocol; wrap an embedding-based retriever with
    `TextEmbeddingRetriever` before handing it to this component. Each retriever is queried concurrently
    on a thread pool, and the combined documents are deduplicated before being returned.

    ### Usage example

    ```python
    from haystack.components.retrievers import InMemoryBM25Retriever, MultiRetriever

    retriever = MultiRetriever(retrievers={"bm25": InMemoryBM25Retriever(document_store=doc_store)}, top_k=3)

    # Query every configured retriever
    result = retriever.run(query="green energy sources")

    # Query only a named subset
    result = retriever.run(query="green energy sources", active_retrievers=["bm25"])
    ```
    """

    def __init__(
        self,
        *,
        retrievers: dict[str, TextRetriever],
        filters: dict[str, Any] | None = None,
        top_k: int = 10,
        max_workers: int = 4,
    ) -> None:
        """
        Create the MultiRetriever component.

        :param retrievers:
            Mapping of names to text retrievers (implementing the `TextRetriever` protocol) to run in parallel.
        :param filters:
            Filters applied when retrieving documents.
        :param top_k:
            Maximum number of documents returned per retriever.
        :param max_workers:
            Maximum number of threads used for parallel retrieval.
        """
        self.retrievers = retrievers
        self.filters = filters
        self.top_k = top_k
        self.max_workers = max_workers
        self._is_warmed_up = False

    def _resolve_retrievers(self, active_retrievers: list[str] | None) -> dict[str, TextRetriever]:
        """
        Select the subset of retrievers to run.

        :param active_retrievers:
            Names of retrievers to run; `None` selects all of them.

        :returns:
            Mapping of retriever names to retriever instances.

        :raises ValueError:
            If a requested name does not match any configured retriever.
        """
        if active_retrievers is None:
            return self.retrievers
        unknown = set(active_retrievers) - self.retrievers.keys()
        if unknown:
            raise ValueError(
                f"Unknown retriever name(s): {sorted(unknown)}. Available retrievers: {sorted(self.retrievers.keys())}"
            )
        return {name: self.retrievers[name] for name in active_retrievers}

    def warm_up(self) -> None:
        """
        Warm up every wrapped retriever that exposes a `warm_up` method. Idempotent.
        """
        if self._is_warmed_up:
            return
        for wrapped in self.retrievers.values():
            warm = getattr(wrapped, "warm_up", None)
            if callable(warm):
                warm()
        self._is_warmed_up = True

    @component.output_types(documents=list[Document])
    def run(
        self,
        query: str,
        filters: dict[str, Any] | None = None,
        top_k: int | None = None,
        *,
        active_retrievers: list[str] | None = None,
    ) -> dict[str, list[Document]]:
        """
        Runs retrievers in parallel on the given query and returns deduplicated results.

        :param query:
            The query to run the retrievers on.
        :param filters:
            Filters to apply. Defaults to the value set at initialization.
        :param top_k:
            Maximum documents to return per retriever. Defaults to the value set at initialization.
        :param active_retrievers:
            Names of retrievers to run. Defaults to all. Must match keys in the `retrievers` dictionary.

        :returns:
            A dictionary with the keys:
            - "documents": A deduplicated list of retrieved documents.

        :raises ValueError:
            If any name in `active_retrievers` does not match a retriever name.
        """
        if not self._is_warmed_up:
            self.warm_up()

        effective_top_k = self.top_k if top_k is None else top_k
        effective_filters = self.filters if filters is None else filters
        selected = self._resolve_retrievers(active_retrievers)

        merged: list[Document] = []
        with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
            pending = {
                pool.submit(wrapped.run, query=query, filters=effective_filters, top_k=effective_top_k): name
                for name, wrapped in selected.items()
            }
            for done in as_completed(pending):
                name = pending[done]
                try:
                    merged.extend(done.result().get("documents", []))
                except Exception as e:
                    # Surface the failing retriever's name so the error is actionable.
                    raise RuntimeError(f"Retriever '{name}' failed: {e}") from e

        return {"documents": _deduplicate_documents(merged)}

    @component.output_types(documents=list[Document])
    async def run_async(
        self,
        query: str,
        filters: dict[str, Any] | None = None,
        top_k: int | None = None,
        *,
        active_retrievers: list[str] | None = None,
    ) -> dict[str, list[Document]]:
        """
        Runs retrievers concurrently on the given query and returns deduplicated results.

        Uses each retriever's `run_async` method if available, otherwise runs `run` in a thread executor.

        :param query:
            The query to run the retrievers on.
        :param filters:
            Filters to apply. Defaults to the value set at initialization.
        :param top_k:
            Maximum documents to return per retriever. Defaults to the value set at initialization.
        :param active_retrievers:
            Names of retrievers to run. Defaults to all. Must match keys in the `retrievers` dictionary.

        :returns:
            A dictionary with the keys:
            - "documents": A deduplicated list of retrieved documents.

        :raises ValueError:
            If any name in `active_retrievers` does not match a retriever name.
        """
        if not self._is_warmed_up:
            self.warm_up()

        effective_top_k = self.top_k if top_k is None else top_k
        effective_filters = self.filters if filters is None else filters
        selected = self._resolve_retrievers(active_retrievers)
        loop = asyncio.get_running_loop()

        async def _query_one(name: str, wrapped: TextRetriever) -> list[Document]:
            # Prefer the retriever's native async path; fall back to the sync run() on the default executor.
            try:
                async_run = getattr(wrapped, "run_async", None)
                if callable(async_run):
                    outcome = await async_run(query=query, filters=effective_filters, top_k=effective_top_k)
                else:
                    outcome = await loop.run_in_executor(
                        None, lambda: wrapped.run(query=query, filters=effective_filters, top_k=effective_top_k)
                    )
                return outcome.get("documents", [])
            except Exception as e:
                raise RuntimeError(f"Retriever '{name}' failed: {e}") from e

        batches = await asyncio.gather(*(_query_one(name, wrapped) for name, wrapped in selected.items()))
        merged = [doc for batch in batches for doc in batch]
        return {"documents": _deduplicate_documents(merged)}

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        serialized_retrievers = {name: component_to_dict(obj=wrapped, name=name) for name, wrapped in self.retrievers.items()}
        return default_to_dict(
            self,
            retrievers=serialized_retrievers,
            filters=self.filters,
            top_k=self.top_k,
            max_workers=self.max_workers,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MultiRetriever":
        """
        Creates an instance of the component from a dictionary.

        :param data:
            Dictionary with the data to create the component.
        """
        serialized = data.get("init_parameters", {}).get("retrievers", {})
        if serialized:
            restored: dict[str, Any] = {}
            for name, spec in serialized.items():
                try:
                    retriever_cls = import_class_by_name(spec["type"])
                except ImportError as e:
                    raise ImportError(
                        f"Could not import class {spec['type']} for retriever '{name}'. Error: {str(e)}"
                    ) from e
                restored[name] = component_from_dict(cls=retriever_cls, data=spec, name=name)
            data["init_parameters"]["retrievers"] = restored
        return default_from_dict(cls, data)
|
||||
@@ -0,0 +1,129 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any
|
||||
|
||||
from haystack import Document, component, default_from_dict, default_to_dict
|
||||
from haystack.components.embedders.types.protocol import TextEmbedder
|
||||
from haystack.components.retrievers.types import EmbeddingRetriever
|
||||
from haystack.core.serialization import component_to_dict
|
||||
|
||||
|
||||
@component
class TextEmbeddingRetriever:
    """
    Retrieves documents for a text query by pairing a text embedder with an embedding-based retriever.

    The query is first converted into an embedding by the wrapped text embedder; that embedding is then
    passed to the wrapped embedding retriever. Results are returned sorted by relevance score, highest first.

    ### Usage example

    ```python
    from haystack.components.embedders import SentenceTransformersTextEmbedder
    from haystack.components.retrievers import InMemoryEmbeddingRetriever, TextEmbeddingRetriever

    retriever = TextEmbeddingRetriever(
        retriever=InMemoryEmbeddingRetriever(document_store=doc_store, top_k=1),
        text_embedder=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
    )
    result = retriever.run(query="Geothermal energy")

    for doc in result["documents"]:
        print(f"Content: {doc.content}, Score: {doc.score}")
    ```
    """

    def __init__(self, *, retriever: EmbeddingRetriever, text_embedder: TextEmbedder) -> None:
        """
        Initialize TextEmbeddingRetriever.

        :param retriever: The embedding-based retriever to use for document retrieval.
        :param text_embedder: The text embedder to convert a text query to an embedding.
        """
        self.retriever = retriever
        self.text_embedder = text_embedder
        self._is_warmed_up = False

    def warm_up(self) -> None:
        """
        Warm up the text embedder and the retriever if either has a `warm_up` method. Idempotent.
        """
        if self._is_warmed_up:
            return
        for part in (self.text_embedder, self.retriever):
            warm = getattr(part, "warm_up", None)
            if callable(warm):
                warm()
        self._is_warmed_up = True

    @component.output_types(documents=list[Document])
    def run(
        self, query: str, filters: dict[str, Any] | None = None, top_k: int | None = None
    ) -> dict[str, list[Document]]:
        """
        Retrieve documents using a single query.

        :param query: The query to retrieve documents for.
        :param filters: A dictionary of filters to apply when retrieving documents.
        :param top_k: The maximum number of documents to return.
        :returns:
            A dictionary containing:
            - `documents`: List of retrieved documents sorted by relevance score.
        """
        if not self._is_warmed_up:
            self.warm_up()

        query_embedding = self.text_embedder.run(text=query)["embedding"]
        retrieved = self.retriever.run(query_embedding=query_embedding, filters=filters, top_k=top_k)

        # Highest score first; missing scores are treated as 0.0.
        ranked = sorted(retrieved["documents"], key=lambda d: d.score or 0.0, reverse=True)
        return {"documents": ranked}

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            A dictionary representing the serialized component.
        """
        return default_to_dict(
            self,
            retriever=component_to_dict(obj=self.retriever, name="retriever"),
            text_embedder=component_to_dict(obj=self.text_embedder, name="text_embedder"),
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TextEmbeddingRetriever":
        """
        Deserializes the component from a dictionary.

        :param data: The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        return default_from_dict(cls, data)
|
||||
@@ -0,0 +1,47 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import functools
|
||||
import warnings
|
||||
from typing import Any, TypeVar
|
||||
|
||||
T = TypeVar("T")


class ExperimentalWarning(UserWarning):
    """Warning emitted when an experimental Haystack component is instantiated."""


def _experimental(cls: type[T]) -> type[T]:
    """
    Class decorator that marks a Haystack component as experimental.

    Decorated classes emit an ``ExperimentalWarning`` when instantiated and gain an
    ``__experimental__ = True`` attribute. Experimental components are subject to
    breaking changes or removal in future releases without prior deprecation notice.

    ## Usage example

    @_experimental
    @component
    class MyComponent:
        ...
    """
    # getattr/setattr are intentional: direct attribute access (cls.__init__, cls.__init__ = ...)
    # trips mypy [misc]/[attr-defined] because T is an unbound TypeVar, and the noqa comments
    # stop ruff B009/B010 from auto-reverting them back to direct access.
    wrapped_init: Any = getattr(cls, "__init__")  # noqa: B009

    @functools.wraps(wrapped_init)
    def warning_init(self: Any, *args: Any, **kwargs: Any) -> None:
        warnings.warn(
            f"'{cls.__name__}' is an experimental component and may change or be removed "
            "in future releases without prior deprecation notice. ",
            ExperimentalWarning,
            stacklevel=2,
        )
        wrapped_init(self, *args, **kwargs)

    setattr(cls, "__init__", warning_init)  # noqa: B010
    setattr(cls, "__experimental__", True)  # noqa: B010
    return cls
|
||||
@@ -2,7 +2,7 @@ loaders:
|
||||
- search_path: [../haystack/components/retrievers]
|
||||
modules: ["auto_merging_retriever", "filter_retriever", "in_memory/bm25_retriever",
  "in_memory/embedding_retriever", "multi_query_embedding_retriever", "multi_query_text_retriever",
  "multi_retriever", "text_embedding_retriever", "sentence_window_retriever"]
|
||||
processors:
|
||||
- type: filter
|
||||
documented_only: true
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Added two new retriever components: ``MultiRetriever`` and ``TextEmbeddingRetriever``.
|
||||
|
||||
``MultiRetriever`` is marked as experimental and may change or be removed in future releases without prior deprecation notice.
|
||||
An ``ExperimentalWarning`` is emitted when this component is initialized.
|
||||
|
||||
``MultiRetriever`` combines multiple text retrievers into a single component.
|
||||
All text retrievers are queried in parallel and their results are deduplicated before being returned.
|
||||
Use the ``active_retrievers`` parameter to enable or disable specific retrievers at runtime.
|
||||
|
||||
``TextEmbeddingRetriever`` wraps an embedding-based retriever together with a text embedder into a single
|
||||
component that implements the ``TextRetriever`` protocol, making it compatible with ``MultiRetriever``.
|
||||
@@ -0,0 +1,455 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import ANY
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack import Document, component
|
||||
from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
|
||||
from haystack.components.retrievers import (
|
||||
InMemoryBM25Retriever,
|
||||
InMemoryEmbeddingRetriever,
|
||||
MultiRetriever,
|
||||
TextEmbeddingRetriever,
|
||||
)
|
||||
from haystack.components.writers import DocumentWriter
|
||||
from haystack.document_stores.in_memory import InMemoryDocumentStore
|
||||
from haystack.document_stores.types import DuplicatePolicy
|
||||
from haystack.utils.experimental import ExperimentalWarning
|
||||
|
||||
# Every test in this module instantiates MultiRetriever, which emits an ExperimentalWarning;
# silence it module-wide so the warning does not pollute test output or trip warning filters.
pytestmark = pytest.mark.filterwarnings("ignore::haystack.utils.experimental.ExperimentalWarning")
|
||||
|
||||
|
||||
@component
class MockRetriever:
    """Test double that returns a canned list of documents for any query."""

    def __init__(self, documents: list[Document] | None = None):
        self.documents = documents or []

    @component.output_types(documents=list[Document])
    def run(self, query: str, filters: dict[str, Any] | None = None, top_k: int | None = None):
        # All arguments are intentionally ignored; the fixed documents are always returned.
        return {"documents": self.documents}
|
||||
|
||||
|
||||
@component
class FailingRetriever:
    """Test double whose run() always raises, simulating a broken retriever backend."""

    @component.output_types(documents=list[Document])
    def run(self, query: str, filters: dict[str, Any] | None = None, top_k: int | None = None):
        raise RuntimeError("connection error")
|
||||
|
||||
|
||||
@pytest.fixture
def sample_documents():
    """Five energy-themed documents with stable ids and a 'category' metadata field."""
    specs = [
        ("doc1", "renewable", "Renewable energy is energy that is collected from renewable resources."),
        ("doc2", "solar", "Solar energy is a type of green energy that is harnessed from the sun."),
        ("doc3", "wind", "Wind energy is another type of green energy that is generated by wind turbines."),
        ("doc4", "geothermal", "Geothermal energy is heat that comes from the sub-surface of the earth."),
        ("doc5", "fossil", "Fossil fuels, such as coal, oil, and natural gas, are non-renewable energy sources."),
    ]
    return [Document(content=text, meta={"category": category}, id=doc_id) for doc_id, category, text in specs]
|
||||
|
||||
|
||||
@pytest.fixture
def document_store_with_embeddings(sample_documents):
    """Create a document store populated with embedded documents."""
    store = InMemoryDocumentStore()
    embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
    writer = DocumentWriter(document_store=store, policy=DuplicatePolicy.SKIP)
    writer.run(documents=embedder.run(sample_documents)["documents"])
    return store
|
||||
|
||||
|
||||
@pytest.fixture
def bm25_retriever(document_store_with_embeddings):
    """A BM25 retriever backed by the shared embedded document store."""
    return InMemoryBM25Retriever(document_store=document_store_with_embeddings)
|
||||
|
||||
|
||||
@pytest.fixture
def embedding_retriever(document_store_with_embeddings):
    """A TextEmbeddingRetriever wrapping an in-memory embedding retriever over the shared store."""
    inner = InMemoryEmbeddingRetriever(document_store=document_store_with_embeddings)
    embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
    return TextEmbeddingRetriever(retriever=inner, text_embedder=embedder)
|
||||
|
||||
|
||||
class TestMultiRetriever:
|
||||
def test_init_default_parameters(self):
|
||||
retrievers = {"mock": MockRetriever()}
|
||||
retriever = MultiRetriever(retrievers=retrievers)
|
||||
assert retriever.retrievers == retrievers
|
||||
assert retriever.filters is None
|
||||
assert retriever.top_k == 10
|
||||
assert retriever.max_workers == 4
|
||||
|
||||
def test_init_custom_parameters(self):
|
||||
retrievers = {"mock": MockRetriever()}
|
||||
retriever = MultiRetriever(retrievers=retrievers, filters={"field": "meta.category"}, top_k=5, max_workers=2)
|
||||
assert retriever.retrievers == retrievers
|
||||
assert retriever.filters == {"field": "meta.category"}
|
||||
assert retriever.top_k == 5
|
||||
assert retriever.max_workers == 2
|
||||
|
||||
def test_run_with_empty_document_store(self):
|
||||
retriever = MultiRetriever(retrievers={"mock": MockRetriever()})
|
||||
result = retriever.run(query="green energy")
|
||||
assert "documents" in result
|
||||
assert result["documents"] == []
|
||||
|
||||
def test_run_combines_results_from_multiple_retrievers(self, sample_documents):
|
||||
retriever = MultiRetriever(
|
||||
retrievers={
|
||||
"a": MockRetriever(documents=[sample_documents[0]]),
|
||||
"b": MockRetriever(documents=[sample_documents[1]]),
|
||||
},
|
||||
max_workers=2,
|
||||
)
|
||||
result = retriever.run(query="energy")
|
||||
assert len(result["documents"]) == 2
|
||||
assert {doc.id for doc in result["documents"]} == {"doc1", "doc2"}
|
||||
|
||||
def test_run_deduplicates_results(self, sample_documents):
|
||||
retriever = MultiRetriever(
|
||||
retrievers={
|
||||
"c": MockRetriever(documents=[sample_documents[0], sample_documents[1]]),
|
||||
"d": MockRetriever(documents=[sample_documents[0]]),
|
||||
},
|
||||
max_workers=2,
|
||||
)
|
||||
result = retriever.run(query="energy")
|
||||
assert len(result["documents"]) == 2
|
||||
ids = [doc.id for doc in result["documents"]]
|
||||
assert ids.count("doc1") == 1
|
||||
|
||||
def test_run_resolves_filters_and_top_k(self):
|
||||
received: dict = {}
|
||||
|
||||
@component
|
||||
class CapturingRetriever:
|
||||
@component.output_types(documents=list[Document])
|
||||
def run(self, query: str, filters: dict[str, Any] | None = None, top_k: int | None = None):
|
||||
received["filters"] = filters
|
||||
received["top_k"] = top_k
|
||||
return {"documents": []}
|
||||
|
||||
retriever = MultiRetriever(
|
||||
retrievers={"capturing": CapturingRetriever()}, filters={"field": "meta.category"}, top_k=5
|
||||
)
|
||||
|
||||
# Should use init-time values when not overridden
|
||||
retriever.run(query="energy")
|
||||
assert received["filters"] == {"field": "meta.category"}
|
||||
assert received["top_k"] == 5
|
||||
|
||||
# Should prefer run-time values when provided
|
||||
retriever.run(query="energy", filters={"field": "meta.other"}, top_k=2)
|
||||
assert received["filters"] == {"field": "meta.other"}
|
||||
assert received["top_k"] == 2
|
||||
|
||||
def test_run_with_active_retrievers(self, sample_documents):
|
||||
retriever = MultiRetriever(
|
||||
retrievers={"a": MockRetriever([sample_documents[0]]), "b": MockRetriever([sample_documents[1]])}
|
||||
)
|
||||
# Only run retriever "a"
|
||||
result = retriever.run(query="energy", active_retrievers=["a"])
|
||||
assert len(result["documents"]) == 1
|
||||
assert result["documents"][0].id == "doc1"
|
||||
|
||||
def test_run_with_unknown_active_retriever_raises(self):
|
||||
retriever = MultiRetriever(retrievers={"mock": MockRetriever()})
|
||||
with pytest.raises(ValueError, match="Unknown retriever name"):
|
||||
retriever.run(query="energy", active_retrievers=["nonexistent"])
|
||||
|
||||
def test_run_retriever_failure_raises_with_name(self):
|
||||
retriever = MultiRetriever(retrievers={"failing": FailingRetriever()})
|
||||
with pytest.raises(RuntimeError, match="Retriever 'failing' failed"):
|
||||
retriever.run(query="energy")
|
||||
|
||||
def test_to_dict(self):
|
||||
retriever = MultiRetriever(
|
||||
retrievers={"bm25": InMemoryBM25Retriever(document_store=InMemoryDocumentStore())},
|
||||
filters=None,
|
||||
top_k=5,
|
||||
max_workers=2,
|
||||
)
|
||||
result = retriever.to_dict()
|
||||
assert result == {
|
||||
"type": "haystack.components.retrievers.multi_retriever.MultiRetriever",
|
||||
"init_parameters": {
|
||||
"retrievers": {
|
||||
"bm25": {
|
||||
"type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever",
|
||||
"init_parameters": {
|
||||
"document_store": {
|
||||
"type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
|
||||
"init_parameters": {
|
||||
"bm25_tokenization_regex": "(?u)\\b\\w+\\b",
|
||||
"bm25_algorithm": "BM25L",
|
||||
"bm25_parameters": {},
|
||||
"embedding_similarity_function": "dot_product",
|
||||
"index": ANY,
|
||||
"return_embedding": True,
|
||||
},
|
||||
},
|
||||
"filters": None,
|
||||
"top_k": 10,
|
||||
"scale_score": False,
|
||||
"filter_policy": "replace",
|
||||
},
|
||||
}
|
||||
},
|
||||
"filters": None,
|
||||
"top_k": 5,
|
||||
"max_workers": 2,
|
||||
},
|
||||
}
|
||||
|
||||
def test_from_dict(self):
|
||||
data = {
|
||||
"type": "haystack.components.retrievers.multi_retriever.MultiRetriever",
|
||||
"init_parameters": {
|
||||
"retrievers": {
|
||||
"bm25": {
|
||||
"type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever",
|
||||
"init_parameters": {
|
||||
"document_store": {
|
||||
"type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
|
||||
"init_parameters": {
|
||||
"bm25_tokenization_regex": "(?u)\\b\\w\\w+\\b",
|
||||
"bm25_algorithm": "BM25L",
|
||||
"bm25_parameters": {},
|
||||
"embedding_similarity_function": "dot_product",
|
||||
"index": "4bb5369d-779f-487b-9c16-3c40f503438b",
|
||||
"return_embedding": True,
|
||||
},
|
||||
},
|
||||
"filters": None,
|
||||
"top_k": 10,
|
||||
"scale_score": False,
|
||||
"filter_policy": "replace",
|
||||
},
|
||||
}
|
||||
},
|
||||
"filters": None,
|
||||
"top_k": 5,
|
||||
"max_workers": 2,
|
||||
},
|
||||
}
|
||||
result = MultiRetriever.from_dict(data)
|
||||
assert isinstance(result, MultiRetriever)
|
||||
assert len(result.retrievers) == 1
|
||||
assert "bm25" in result.retrievers
|
||||
assert isinstance(result.retrievers["bm25"], InMemoryBM25Retriever)
|
||||
assert result.top_k == 5
|
||||
assert result.max_workers == 2
|
||||
|
||||
def test_from_dict_with_no_retrievers(self):
    """Deserializing a payload whose retriever mapping is empty yields an empty MultiRetriever."""
    payload = {
        "type": "haystack.components.retrievers.multi_retriever.MultiRetriever",
        "init_parameters": {"retrievers": {}, "filters": None, "top_k": 10, "max_workers": 4},
    }

    restored = MultiRetriever.from_dict(payload)

    assert isinstance(restored, MultiRetriever)
    assert restored.retrievers == {}
|
||||
|
||||
def test_from_dict_with_unknown_retriever_type_raises(self):
    """An unresolvable retriever class path in the payload surfaces as an ImportError."""
    bad_entry = {"type": "haystack.components.retrievers.NonExistentRetriever", "init_parameters": {}}
    payload = {
        "type": "haystack.components.retrievers.multi_retriever.MultiRetriever",
        "init_parameters": {
            "retrievers": {"bad": bad_entry},
            "filters": None,
            "top_k": 10,
            "max_workers": 4,
        },
    }

    with pytest.raises(ImportError, match="Could not import class"):
        MultiRetriever.from_dict(payload)
|
||||
|
||||
@pytest.mark.integration
@pytest.mark.slow
def test_run_with_filters(self, del_hf_env_vars, bm25_retriever, embedding_retriever):
    """Filters passed to run() restrict the merged results from all sub-retrievers."""
    # bm25_retriever / embedding_retriever fixtures are defined elsewhere in this file;
    # presumably they share a corpus with exactly one document tagged category == "solar" — confirm.
    retriever = MultiRetriever(retrievers={"bm25": bm25_retriever, "embedding": embedding_retriever})
    result = retriever.run(query="energy", filters={"field": "meta.category", "operator": "==", "value": "solar"})
    assert len(result["documents"]) == 1
    assert result["documents"][0].meta["category"] == "solar"
|
||||
|
||||
@pytest.mark.integration
@pytest.mark.slow
def test_run_with_top_k(self, del_hf_env_vars, bm25_retriever, embedding_retriever):
    """A per-call top_k override caps the size of the merged result list."""
    retriever = MultiRetriever(retrievers={"bm25": bm25_retriever, "embedding": embedding_retriever})
    result = retriever.run(query="energy", top_k=2)
    # Exactly two documents survive the merge when top_k=2 is requested.
    assert len(result["documents"]) == 2
|
||||
|
||||
@pytest.mark.integration
@pytest.mark.slow
def test_run_with_active_retrievers_integration(self, del_hf_env_vars, bm25_retriever, embedding_retriever):
    """Restricting active_retrievers to one name must match running that retriever alone."""
    retriever = MultiRetriever(retrievers={"bm25": bm25_retriever, "embedding": embedding_retriever})
    result_bm25_active = retriever.run(query="energy", active_retrievers=["bm25"])
    result_bm25 = bm25_retriever.run(query="energy")
    # With only "bm25" active, the multi-retriever output is identical to the standalone run.
    assert result_bm25_active == result_bm25
|
||||
|
||||
|
||||
class TestMultiRetrieverAsync:
    """Async counterpart of the MultiRetriever tests, exercising run_async()."""

    @pytest.mark.asyncio
    async def test_run_async_with_empty_results(self):
        """run_async with a retriever that returns nothing yields an empty documents list."""
        retriever = MultiRetriever(retrievers={"mock": MockRetriever()})
        result = await retriever.run_async(query="green energy")
        assert "documents" in result
        assert result["documents"] == []

    @pytest.mark.asyncio
    async def test_run_async_combines_results_from_multiple_retrievers(self, sample_documents):
        """Results from all sub-retrievers are merged into a single documents list."""
        retriever = MultiRetriever(
            retrievers={
                "a": MockRetriever(documents=[sample_documents[0]]),
                "b": MockRetriever(documents=[sample_documents[1]]),
            }
        )
        result = await retriever.run_async(query="energy")
        assert len(result["documents"]) == 2
        # Both distinct documents (one per sub-retriever) are present.
        assert {doc.id for doc in result["documents"]} == {"doc1", "doc2"}

    @pytest.mark.asyncio
    async def test_run_async_deduplicates_results(self, sample_documents):
        """A document returned by several sub-retrievers appears only once in the merge."""
        retriever = MultiRetriever(
            retrievers={
                "c": MockRetriever(documents=[sample_documents[0], sample_documents[1]]),
                "d": MockRetriever(documents=[sample_documents[0]]),
            }
        )
        result = await retriever.run_async(query="energy")
        assert len(result["documents"]) == 2
        # sample_documents[0] was returned by both "c" and "d" but is kept once.
        assert [doc.id for doc in result["documents"]].count("doc1") == 1

    @pytest.mark.asyncio
    async def test_run_async_resolves_filters_and_top_k(self):
        """Init-time filters/top_k are used by default; per-call values override them."""
        # Captures the kwargs the inner retriever actually receives, via closure.
        received: dict = {}

        @component
        class CapturingRetriever:
            @component.output_types(documents=list[Document])
            def run(self, query: str, filters: dict[str, Any] | None = None, top_k: int | None = None):
                received["filters"] = filters
                received["top_k"] = top_k
                return {"documents": []}

        retriever = MultiRetriever(
            retrievers={"capturing": CapturingRetriever()}, filters={"field": "meta.category"}, top_k=5
        )

        # No per-call overrides: the init-time defaults are forwarded.
        await retriever.run_async(query="energy")
        assert received["filters"] == {"field": "meta.category"}
        assert received["top_k"] == 5

        # Per-call values take precedence over the init-time defaults.
        await retriever.run_async(query="energy", filters={"field": "meta.other"}, top_k=2)
        assert received["filters"] == {"field": "meta.other"}
        assert received["top_k"] == 2

    @pytest.mark.asyncio
    async def test_run_async_with_active_retrievers(self, sample_documents):
        """Only the retrievers named in active_retrievers contribute results."""
        retriever = MultiRetriever(
            retrievers={"a": MockRetriever([sample_documents[0]]), "b": MockRetriever([sample_documents[1]])}
        )
        result = await retriever.run_async(query="energy", active_retrievers=["a"])
        assert len(result["documents"]) == 1
        assert result["documents"][0].id == "doc1"

    @pytest.mark.asyncio
    async def test_run_async_with_unknown_active_retriever_raises(self):
        """An active_retrievers name not present in the mapping raises ValueError."""
        retriever = MultiRetriever(retrievers={"mock": MockRetriever()})
        with pytest.raises(ValueError, match="Unknown retriever name"):
            await retriever.run_async(query="energy", active_retrievers=["nonexistent"])

    @pytest.mark.asyncio
    async def test_run_async_retriever_failure_raises_with_name(self):
        """A sub-retriever failure is re-raised as RuntimeError naming the failing retriever."""
        retriever = MultiRetriever(retrievers={"failing": FailingRetriever()})
        with pytest.raises(RuntimeError, match="Retriever 'failing' failed"):
            await retriever.run_async(query="energy")

    @pytest.mark.asyncio
    async def test_run_async_uses_run_async_on_retriever_if_available(self):
        """run_async prefers a sub-retriever's native run_async over its sync run."""

        @component
        class AsyncCapableRetriever:
            def __init__(self):
                # Flipped to True only when the async path is taken.
                self.used_async = False

            @component.output_types(documents=list[Document])
            def run(self, query: str, filters: dict[str, Any] | None = None, top_k: int | None = None):
                return {"documents": []}

            @component.output_types(documents=list[Document])
            async def run_async(self, query: str, filters: dict[str, Any] | None = None, top_k: int | None = None):
                self.used_async = True
                return {"documents": [Document(content="async result", id="async1")]}

        inner = AsyncCapableRetriever()
        retriever = MultiRetriever(retrievers={"async_capable": inner})
        result = await retriever.run_async(query="energy")
        # The sync run() returns no documents, so getting "async1" proves run_async was used.
        assert inner.used_async is True
        assert len(result["documents"]) == 1
        assert result["documents"][0].id == "async1"

    @pytest.mark.integration
    @pytest.mark.slow
    @pytest.mark.asyncio
    async def test_run_async_with_filters(self, del_hf_env_vars, bm25_retriever, embedding_retriever):
        """Async variant: filters restrict merged results across sub-retrievers."""
        retriever = MultiRetriever(retrievers={"bm25": bm25_retriever, "embedding": embedding_retriever})
        result = await retriever.run_async(
            query="energy", filters={"field": "meta.category", "operator": "==", "value": "solar"}
        )
        assert len(result["documents"]) == 1
        assert result["documents"][0].meta["category"] == "solar"

    @pytest.mark.integration
    @pytest.mark.slow
    @pytest.mark.asyncio
    async def test_run_async_with_top_k(self, del_hf_env_vars, bm25_retriever, embedding_retriever):
        """Async variant: per-call top_k caps the merged result list."""
        retriever = MultiRetriever(retrievers={"bm25": bm25_retriever, "embedding": embedding_retriever})
        result = await retriever.run_async(query="energy", top_k=2)
        assert len(result["documents"]) == 2

    @pytest.mark.integration
    @pytest.mark.slow
    @pytest.mark.asyncio
    async def test_run_async_with_active_retrievers_integration(
        self, del_hf_env_vars, bm25_retriever, embedding_retriever
    ):
        """Async variant: a single active retriever matches running it standalone."""
        retriever = MultiRetriever(retrievers={"bm25": bm25_retriever, "embedding": embedding_retriever})
        result_bm25_active = await retriever.run_async(query="energy", active_retrievers=["bm25"])
        result_bm25 = await bm25_retriever.run_async(query="energy")
        assert result_bm25_active == result_bm25
|
||||
|
||||
|
||||
class TestMultiRetrieverExperimental:
    """Checks that MultiRetriever is marked and announced as an experimental component."""

    @pytest.mark.filterwarnings("always::haystack.utils.experimental.ExperimentalWarning")
    def test_emits_experimental_warning_on_init(self):
        """Instantiating MultiRetriever emits an ExperimentalWarning naming the class."""
        with pytest.warns(ExperimentalWarning, match="MultiRetriever.*experimental"):
            MultiRetriever(retrievers={"mock": MockRetriever()})

    @pytest.mark.filterwarnings("always::haystack.utils.experimental.ExperimentalWarning")
    def test_experimental_attribute_is_set(self):
        """The class carries the __experimental__ marker attribute set to True."""
        assert getattr(MultiRetriever, "__experimental__", False) is True
|
||||
@@ -0,0 +1,193 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import ANY
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from haystack import Document, component
|
||||
from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
|
||||
from haystack.components.retrievers import InMemoryEmbeddingRetriever, TextEmbeddingRetriever
|
||||
from haystack.components.writers import DocumentWriter
|
||||
from haystack.document_stores.in_memory import InMemoryDocumentStore
|
||||
from haystack.document_stores.types import DuplicatePolicy
|
||||
|
||||
|
||||
@component
class MockTextEmbedder:
    """Stub text embedder: maps any input text to a constant 384-dim all-ones vector."""

    @component.output_types(embedding=list[float])
    def run(self, text: str) -> dict[str, list[float]]:
        # A fixed vector is enough for wiring tests; no embedding model is loaded.
        return {"embedding": [1.0] * 384}
|
||||
|
||||
|
||||
class TestTextEmbeddingRetriever:
    """Tests for TextEmbeddingRetriever: a text embedder chained with an embedding retriever."""

    @pytest.fixture
    def sample_documents(self):
        """Five plain-content documents about energy (no metadata attached)."""
        return [
            Document(content="Renewable energy is energy that is collected from renewable resources."),
            Document(content="Solar energy is a type of green energy that is harnessed from the sun."),
            Document(content="Wind energy is another type of green energy that is generated by wind turbines."),
            Document(content="Geothermal energy is heat that comes from the sub-surface of the earth."),
            Document(content="Fossil fuels, such as coal, oil, and natural gas, are non-renewable energy sources."),
        ]

    @pytest.fixture
    def document_store_with_embeddings(self, sample_documents):
        """Create a document store populated with embedded documents."""
        document_store = InMemoryDocumentStore()
        doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
        doc_writer = DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP)

        embedded_docs = doc_embedder.run(sample_documents)["documents"]
        doc_writer.run(documents=embedded_docs)
        return document_store

    def test_init(self):
        """The constructor stores the wrapped retriever and text embedder as-is."""
        embedding_retriever = InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore())
        text_embedder = MockTextEmbedder()
        retriever = TextEmbeddingRetriever(retriever=embedding_retriever, text_embedder=text_embedder)
        assert retriever.retriever == embedding_retriever
        assert retriever.text_embedder == text_embedder

    def test_run_with_empty_document_store(self):
        """Running against an empty store returns an empty documents list, not an error."""
        retriever = TextEmbeddingRetriever(
            retriever=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()),
            text_embedder=MockTextEmbedder(),
        )
        result = retriever.run(query="green energy")
        assert "documents" in result
        assert result["documents"] == []

    def test_run_returns_documents_sorted_by_score(self):
        """run() re-sorts the inner retriever's output by descending score."""
        doc_high = Document(content="Solar energy", id="doc1", score=0.9)
        doc_low = Document(content="Fossil fuels", id="doc2", score=0.3)
        doc_mid = Document(content="Wind energy", id="doc3", score=0.6)

        # Inner retriever deliberately returns documents out of score order.
        @component
        class MockRetriever:
            @component.output_types(documents=list[Document])
            def run(
                self, query_embedding: list[float], filters: dict[str, Any] | None = None, top_k: int | None = None
            ) -> dict[str, list[Document]]:
                return {"documents": [doc_low, doc_high, doc_mid]}

        retriever = TextEmbeddingRetriever(retriever=MockRetriever(), text_embedder=MockTextEmbedder())
        result = retriever.run(query="energy")

        scores = [doc.score for doc in result["documents"]]
        assert scores == sorted(scores, reverse=True)

    def test_to_dict(self):
        """to_dict serializes the wrapped retriever and text embedder recursively."""
        retriever = TextEmbeddingRetriever(
            retriever=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()),
            text_embedder=MockTextEmbedder(),
        )
        result = retriever.to_dict()
        assert result == {
            "type": "haystack.components.retrievers.text_embedding_retriever.TextEmbeddingRetriever",
            "init_parameters": {
                "retriever": {
                    "type": "haystack.components.retrievers.in_memory.embedding_retriever.InMemoryEmbeddingRetriever",
                    "init_parameters": {
                        "document_store": {
                            "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
                            "init_parameters": {
                                "bm25_tokenization_regex": "(?u)\\b\\w+\\b",
                                "bm25_algorithm": "BM25L",
                                "bm25_parameters": {},
                                "embedding_similarity_function": "dot_product",
                                # The store generates a fresh index id per instance, so only
                                # its presence is checked, not its value.
                                "index": ANY,
                                "return_embedding": True,
                            },
                        },
                        "filters": None,
                        "top_k": 10,
                        "scale_score": False,
                        "return_embedding": False,
                        "filter_policy": "replace",
                    },
                },
                "text_embedder": {
                    "type": "retrievers.test_text_embedding_retriever.MockTextEmbedder",
                    "init_parameters": {},
                },
            },
        }

    def test_from_dict(self):
        """from_dict rebuilds the component with live retriever and embedder instances."""
        data = {
            "type": "haystack.components.retrievers.text_embedding_retriever.TextEmbeddingRetriever",
            "init_parameters": {
                "retriever": {
                    "type": "haystack.components.retrievers.in_memory.embedding_retriever.InMemoryEmbeddingRetriever",
                    "init_parameters": {
                        "document_store": {
                            "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
                            "init_parameters": {
                                "bm25_tokenization_regex": "(?u)\\b\\w\\w+\\b",
                                "bm25_algorithm": "BM25L",
                                "bm25_parameters": {},
                                "embedding_similarity_function": "dot_product",
                                "index": "4bb5369d-779f-487b-9c16-3c40f503438b",
                                "return_embedding": True,
                            },
                        },
                        "filters": None,
                        "top_k": 10,
                        "scale_score": False,
                        "return_embedding": False,
                        "filter_policy": "replace",
                    },
                },
                "text_embedder": {
                    "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",  # noqa E501
                    "init_parameters": {
                        "model": "sentence-transformers/all-MiniLM-L6-v2",
                        "token": {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False},
                        "prefix": "",
                        "suffix": "",
                        "batch_size": 32,
                        "progress_bar": True,
                        "normalize_embeddings": False,
                        "trust_remote_code": False,
                        "local_files_only": False,
                        "truncate_dim": None,
                        "model_kwargs": None,
                        "tokenizer_kwargs": None,
                        "config_kwargs": None,
                        "precision": "float32",
                        "encode_kwargs": None,
                        "backend": "torch",
                    },
                },
            },
        }
        result = TextEmbeddingRetriever.from_dict(data)
        assert isinstance(result, TextEmbeddingRetriever)
        assert isinstance(result.retriever, InMemoryEmbeddingRetriever)
        assert isinstance(result.text_embedder, SentenceTransformersTextEmbedder)

    @pytest.mark.integration
    @pytest.mark.slow
    def test_run_with_filters(self, del_hf_env_vars, document_store_with_embeddings):
        """Filters passed to run() are forwarded to the inner embedding retriever."""
        retriever = TextEmbeddingRetriever(
            retriever=InMemoryEmbeddingRetriever(document_store=document_store_with_embeddings),
            text_embedder=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
        )
        result = retriever.run(query="energy", filters={"field": "meta.category", "operator": "==", "value": "solar"})
        assert "documents" in result
        # NOTE(review): sample_documents carry no "category" meta, so this filter likely
        # returns zero documents and all(...) passes vacuously — confirm the fixture intent.
        assert all(doc.meta.get("category") == "solar" for doc in result["documents"])

    @pytest.mark.integration
    @pytest.mark.slow
    def test_run_with_top_k(self, del_hf_env_vars, document_store_with_embeddings):
        """A per-call top_k caps the number of returned documents."""
        retriever = TextEmbeddingRetriever(
            retriever=InMemoryEmbeddingRetriever(document_store=document_store_with_embeddings),
            text_embedder=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
        )
        result = retriever.run(query="energy", top_k=2)
        assert "documents" in result
        assert len(result["documents"]) <= 2
|
||||
@@ -0,0 +1,87 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack import component
|
||||
from haystack.utils.experimental import ExperimentalWarning, _experimental
|
||||
|
||||
|
||||
class TestExperimentalDecorator:
    """Behavioral contract of the _experimental class decorator and ExperimentalWarning."""

    def test_emits_experimental_warning_on_init(self):
        """Instantiating a decorated component emits an ExperimentalWarning."""

        @_experimental
        @component
        class MyComponent:
            @component.output_types(value=int)
            def run(self, value: int) -> dict:
                return {"value": value}

        with pytest.warns(ExperimentalWarning):
            MyComponent()

    def test_warning_message_contains_class_name(self):
        """The warning text names the decorated class so users can locate it."""

        @_experimental
        @component
        class MyComponent:
            @component.output_types(value=int)
            def run(self, value: int) -> dict:
                return {"value": value}

        with pytest.warns(ExperimentalWarning, match="MyComponent"):
            MyComponent()

    def test_sets_experimental_attribute(self):
        """The decorator marks the class with __experimental__ = True."""

        @_experimental
        @component
        class MyComponent:
            @component.output_types(value=int)
            def run(self, value: int) -> dict:
                return {"value": value}

        assert MyComponent.__experimental__ is True

    def test_passes_args_and_kwargs_to_init(self):
        """The wrapped __init__ forwards positional and keyword arguments unchanged."""

        @_experimental
        @component
        class MyComponent:
            def __init__(self, value: int, label: str = "default"):
                self.value = value
                self.label = label

            @component.output_types(value=int)
            def run(self, value: int) -> dict:
                return {"value": value}

        with pytest.warns(ExperimentalWarning):
            instance = MyComponent(42, label="custom")

        assert instance.value == 42
        assert instance.label == "custom"

    def test_preserves_init_name(self):
        """Wrapping keeps __init__'s __name__ intact (functools.wraps-style behavior)."""

        @_experimental
        @component
        class MyComponent:
            @component.output_types(value=int)
            def run(self, value: int) -> dict:
                return {"value": value}

        assert MyComponent.__init__.__name__ == "__init__"

    def test_experimental_warning_is_user_warning_subclass(self):
        """ExperimentalWarning derives from UserWarning so default filters show it."""
        assert issubclass(ExperimentalWarning, UserWarning)

    def test_warning_emitted_on_every_instantiation(self):
        """The warning fires on each construction, not just the first."""

        @_experimental
        @component
        class MyComponent:
            @component.output_types(value=int)
            def run(self, value: int) -> dict:
                return {"value": value}

        with pytest.warns(ExperimentalWarning):
            MyComponent()

        with pytest.warns(ExperimentalWarning):
            MyComponent()
|
||||
Reference in New Issue
Block a user