feat: Add new components TextEmbeddingRetriever and MultiRetriever (#10872)

This commit is contained in:
Sebastian Husch Lee
2026-04-29 09:29:01 +02:00
committed by GitHub
parent 02e0b42132
commit 2a4f104edf
9 changed files with 1214 additions and 1 deletions
@@ -11,9 +11,11 @@ _import_structure = {
"auto_merging_retriever": ["AutoMergingRetriever"],
"filter_retriever": ["FilterRetriever"],
"in_memory": ["InMemoryBM25Retriever", "InMemoryEmbeddingRetriever"],
"multi_retriever": ["MultiRetriever"],
"multi_query_embedding_retriever": ["MultiQueryEmbeddingRetriever"],
"multi_query_text_retriever": ["MultiQueryTextRetriever"],
"sentence_window_retriever": ["SentenceWindowRetriever"],
"text_embedding_retriever": ["TextEmbeddingRetriever"],
}
if TYPE_CHECKING:
@@ -23,7 +25,9 @@ if TYPE_CHECKING:
from .in_memory import InMemoryEmbeddingRetriever as InMemoryEmbeddingRetriever
from .multi_query_embedding_retriever import MultiQueryEmbeddingRetriever as MultiQueryEmbeddingRetriever
from .multi_query_text_retriever import MultiQueryTextRetriever as MultiQueryTextRetriever
from .multi_retriever import MultiRetriever as MultiRetriever
from .sentence_window_retriever import SentenceWindowRetriever as SentenceWindowRetriever
from .text_embedding_retriever import TextEmbeddingRetriever as TextEmbeddingRetriever
else:
sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
@@ -0,0 +1,284 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import asyncio
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any
from haystack import component, default_from_dict, default_to_dict
from haystack.components.retrievers.types.protocol import TextRetriever
from haystack.core.serialization import component_from_dict, component_to_dict, import_class_by_name
from haystack.dataclasses import Document
from haystack.utils.experimental import _experimental
from haystack.utils.misc import _deduplicate_documents
@_experimental
@component
class MultiRetriever:
"""
A component that accepts text retrievers and runs them in parallel, combining their results.
> **Note:** This component is experimental and may change or be removed in future releases without prior
deprecation notice.
All retrievers must implement the `TextRetriever` protocol. Use `TextEmbeddingRetriever` to wrap an
embedding-based retriever before passing it to this component.
Each retriever is queried concurrently using a thread pool.
The results are deduplicated and returned as a single list of documents.
### Usage example
```python
from haystack import Document
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.types import DuplicatePolicy
from haystack.components.retrievers import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
from haystack.components.retrievers import TextEmbeddingRetriever, MultiRetriever
from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
from haystack.components.writers import DocumentWriter
documents = [
Document(content="Renewable energy is energy that is collected from renewable resources."),
Document(content="Solar energy is a type of green energy that is harnessed from the sun."),
Document(content="Wind energy is another type of green energy that is generated by wind turbines."),
]
# Populate the document store
doc_store = InMemoryDocumentStore()
doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
doc_writer = DocumentWriter(document_store=doc_store, policy=DuplicatePolicy.SKIP)
doc_writer.run(documents=doc_embedder.run(documents)["documents"])
# Run the multi-retriever with all retrievers
retriever = MultiRetriever(
retrievers={
"bm25": InMemoryBM25Retriever(document_store=doc_store),
"embedding": TextEmbeddingRetriever(
retriever=InMemoryEmbeddingRetriever(document_store=doc_store),
text_embedder=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
),
},
top_k=3,
)
# Run all retrievers
result = retriever.run(query="green energy sources")
# Run only the BM25 retriever
result = retriever.run(query="green energy sources", active_retrievers=["bm25"])
for doc in result["documents"]:
print(doc.content)
```
"""
def __init__(
self,
*,
retrievers: dict[str, TextRetriever],
filters: dict[str, Any] | None = None,
top_k: int = 10,
max_workers: int = 4,
) -> None:
"""
Create the MultiRetriever component.
:param retrievers:
A dictionary mapping names to text retrievers (implementing the `TextRetriever` protocol) to run in
parallel.
:param filters:
A dictionary of filters to apply when retrieving documents.
:param top_k:
The maximum number of documents to return per retriever.
:param max_workers:
The maximum number of threads to use for parallel retrieval.
"""
self.retrievers = retrievers
self.filters = filters
self.top_k = top_k
self.max_workers = max_workers
self._is_warmed_up = False
def _resolve_retrievers(self, active_retrievers: list[str] | None) -> dict[str, TextRetriever]:
"""
Returns the subset of retrievers to run based on the active_retrievers list.
:param active_retrievers:
A list of retriever names to run. If None, all retrievers are returned.
:returns:
A dictionary of retriever names to retriever instances.
:raises ValueError:
If any name in `active_retrievers` does not match a retriever name.
"""
if active_retrievers is None:
return self.retrievers
unknown = set(active_retrievers) - self.retrievers.keys()
if unknown:
raise ValueError(
f"Unknown retriever name(s): {sorted(unknown)}. Available retrievers: {sorted(self.retrievers.keys())}"
)
return {name: self.retrievers[name] for name in active_retrievers}
def warm_up(self) -> None:
"""
Warm up the retrievers if any has a warm_up method.
"""
if self._is_warmed_up:
return
for retriever in self.retrievers.values():
if hasattr(retriever, "warm_up") and callable(retriever.warm_up):
retriever.warm_up()
self._is_warmed_up = True
@component.output_types(documents=list[Document])
def run(
self,
query: str,
filters: dict[str, Any] | None = None,
top_k: int | None = None,
*,
active_retrievers: list[str] | None = None,
) -> dict[str, list[Document]]:
"""
Runs retrievers in parallel on the given query and returns deduplicated results.
:param query:
The query to run the retrievers on.
:param filters:
Filters to apply. Defaults to the value set at initialization.
:param top_k:
Maximum documents to return per retriever. Defaults to the value set at initialization.
:param active_retrievers:
Names of retrievers to run. Defaults to all. Must match keys in the `retrievers` dictionary.
:returns:
A dictionary with the keys:
- "documents": A deduplicated list of retrieved documents.
:raises ValueError:
If any name in `active_retrievers` does not match a retriever name.
"""
if not self._is_warmed_up:
self.warm_up()
resolved_top_k = top_k if top_k is not None else self.top_k
resolved_filters = filters if filters is not None else self.filters
retrievers_to_run = self._resolve_retrievers(active_retrievers)
all_documents: list[Document] = []
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
future_to_name = {
executor.submit(retriever.run, query=query, filters=resolved_filters, top_k=resolved_top_k): name
for name, retriever in retrievers_to_run.items()
}
for future in as_completed(future_to_name):
name = future_to_name[future]
try:
all_documents.extend(future.result().get("documents", []))
except Exception as e:
raise RuntimeError(f"Retriever '{name}' failed: {e}") from e
return {"documents": _deduplicate_documents(all_documents)}
@component.output_types(documents=list[Document])
async def run_async(
self,
query: str,
filters: dict[str, Any] | None = None,
top_k: int | None = None,
*,
active_retrievers: list[str] | None = None,
) -> dict[str, list[Document]]:
"""
Runs retrievers concurrently on the given query and returns deduplicated results.
Uses each retriever's `run_async` method if available, otherwise runs `run` in a thread executor.
:param query:
The query to run the retrievers on.
:param filters:
Filters to apply. Defaults to the value set at initialization.
:param top_k:
Maximum documents to return per retriever. Defaults to the value set at initialization.
:param active_retrievers:
Names of retrievers to run. Defaults to all. Must match keys in the `retrievers` dictionary.
:returns:
A dictionary with the keys:
- "documents": A deduplicated list of retrieved documents.
:raises ValueError:
If any name in `active_retrievers` does not match a retriever name.
"""
if not self._is_warmed_up:
self.warm_up()
resolved_top_k = top_k if top_k is not None else self.top_k
resolved_filters = filters if filters is not None else self.filters
retrievers_to_run = self._resolve_retrievers(active_retrievers)
loop = asyncio.get_running_loop()
async def _run_one(name: str, retriever: TextRetriever) -> list[Document]:
try:
if hasattr(retriever, "run_async") and callable(retriever.run_async):
result = await retriever.run_async(query=query, filters=resolved_filters, top_k=resolved_top_k)
else:
result = await loop.run_in_executor(
None, lambda: retriever.run(query=query, filters=resolved_filters, top_k=resolved_top_k)
)
return result.get("documents", [])
except Exception as e:
raise RuntimeError(f"Retriever '{name}' failed: {e}") from e
results = await asyncio.gather(*[_run_one(name, r) for name, r in retrievers_to_run.items()])
all_documents: list[Document] = []
for docs in results:
all_documents.extend(docs)
return {"documents": _deduplicate_documents(all_documents)}
def to_dict(self) -> dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
retrievers={name: component_to_dict(obj=r, name=name) for name, r in self.retrievers.items()},
filters=self.filters,
top_k=self.top_k,
max_workers=self.max_workers,
)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "MultiRetriever":
"""
Creates an instance of the component from a dictionary.
:param data:
Dictionary with the data to create the component.
"""
retrievers_data = data.get("init_parameters", {}).get("retrievers", {})
if retrievers_data:
retrievers = {}
for name, retriever_data in retrievers_data.items():
try:
imported_class = import_class_by_name(retriever_data["type"])
except ImportError as e:
raise ImportError(
f"Could not import class {retriever_data['type']} for retriever '{name}'. Error: {str(e)}"
) from e
retrievers[name] = component_from_dict(cls=imported_class, data=retriever_data, name=name)
data["init_parameters"]["retrievers"] = retrievers
return default_from_dict(cls, data)
@@ -0,0 +1,129 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any
from haystack import Document, component, default_from_dict, default_to_dict
from haystack.components.embedders.types.protocol import TextEmbedder
from haystack.components.retrievers.types import EmbeddingRetriever
from haystack.core.serialization import component_to_dict
@component
class TextEmbeddingRetriever:
"""
A component that retrieves documents using a query with an embedding-based retriever.
This component takes a text query, converts it to an embedding using a text embedder, and then uses an
embedding-based retriever to find relevant documents.
The results are sorted by relevance score.
### Usage example
```python
from haystack import Document
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.types import DuplicatePolicy
from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
from haystack.components.retrievers import InMemoryEmbeddingRetriever, TextEmbeddingRetriever
from haystack.components.writers import DocumentWriter
documents = [
Document(content="Renewable energy is energy that is collected from renewable resources."),
Document(content="Solar energy is a type of green energy that is harnessed from the sun."),
Document(content="Wind energy is another type of green energy that is generated by wind turbines."),
Document(content="Geothermal energy is heat that comes from the sub-surface of the earth."),
Document(content="Biomass energy is produced from organic materials, such as plant and animal waste."),
Document(content="Fossil fuels, such as coal, oil, and natural gas, are non-renewable energy sources."),
]
# Populate the document store
doc_store = InMemoryDocumentStore()
doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
doc_writer = DocumentWriter(document_store=doc_store, policy=DuplicatePolicy.SKIP)
documents = doc_embedder.run(documents)["documents"]
doc_writer.run(documents=documents)
# Run the retriever
in_memory_retriever = InMemoryEmbeddingRetriever(document_store=doc_store, top_k=1)
text_embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
retriever = TextEmbeddingRetriever(retriever=in_memory_retriever, text_embedder=text_embedder)
result = retriever.run(query="Geothermal energy")
for doc in result["documents"]:
print(f"Content: {doc.content}, Score: {doc.score}")
# >> Content: Geothermal energy is heat that comes from the sub-surface of the earth., Score: 0.8509603046266574
```
"""
def __init__(self, *, retriever: EmbeddingRetriever, text_embedder: TextEmbedder) -> None:
"""
Initialize TextEmbeddingRetriever.
:param retriever: The embedding-based retriever to use for document retrieval.
:param text_embedder: The text embedder to convert a text query to an embedding.
"""
self.retriever = retriever
self.text_embedder = text_embedder
self._is_warmed_up = False
def warm_up(self) -> None:
"""
Warm up the text embedder and the retriever if any has a warm_up method.
"""
if not self._is_warmed_up:
if hasattr(self.text_embedder, "warm_up") and callable(self.text_embedder.warm_up):
self.text_embedder.warm_up()
if hasattr(self.retriever, "warm_up") and callable(self.retriever.warm_up):
self.retriever.warm_up()
self._is_warmed_up = True
@component.output_types(documents=list[Document])
def run(
self, query: str, filters: dict[str, Any] | None = None, top_k: int | None = None
) -> dict[str, list[Document]]:
"""
Retrieve documents using a single query.
:param query: The query to retrieve documents for.
:param filters: A dictionary of filters to apply when retrieving documents.
:param top_k: The maximum number of documents to return.
:returns:
A dictionary containing:
- `documents`: List of retrieved documents sorted by relevance score.
"""
if not self._is_warmed_up:
self.warm_up()
embedding_result = self.text_embedder.run(text=query)
result = self.retriever.run(query_embedding=embedding_result["embedding"], filters=filters, top_k=top_k)
docs: list[Document] = result["documents"]
# sort
docs.sort(key=lambda x: x.score or 0.0, reverse=True)
return {"documents": docs}
def to_dict(self) -> dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
A dictionary representing the serialized component.
"""
return default_to_dict(
self,
retriever=component_to_dict(obj=self.retriever, name="retriever"),
text_embedder=component_to_dict(obj=self.text_embedder, name="text_embedder"),
)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "TextEmbeddingRetriever":
"""
Deserializes the component from a dictionary.
:param data: The dictionary to deserialize from.
:returns:
The deserialized component.
"""
return default_from_dict(cls, data)
+47
View File
@@ -0,0 +1,47 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import functools
import warnings
from typing import Any, TypeVar
T = TypeVar("T")
def _experimental(cls: type[T]) -> type[T]:
"""
Class decorator that marks a Haystack component as experimental.
Components decorated with @experimental are subject to breaking changes
or removal in future releases without prior deprecation notice.
## Usage example
@_experimental
@component
class MyComponent:
...
"""
# getattr/setattr are intentional here: direct attribute access (cls.__init__, cls.__init__ = ...)
# triggers mypy [misc] and [attr-defined] errors because T is an unbound TypeVar.
# noqa comments suppress ruff B009/B010 which would auto-revert these back to direct access.
original_init: Any = getattr(cls, "__init__") # noqa: B009
@functools.wraps(original_init)
def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
warnings.warn(
f"'{cls.__name__}' is an experimental component and may change or be removed "
"in future releases without prior deprecation notice. ",
ExperimentalWarning,
stacklevel=2,
)
original_init(self, *args, **kwargs)
setattr(cls, "__init__", new_init) # noqa: B010
setattr(cls, "__experimental__", True) # noqa: B010
return cls
class ExperimentalWarning(UserWarning):
"""Warning emitted when an experimental Haystack component is instantiated."""
+1 -1
View File
@@ -2,7 +2,7 @@ loaders:
- search_path: [../haystack/components/retrievers]
modules: ["auto_merging_retriever", "filter_retriever", "in_memory/bm25_retriever",
"in_memory/embedding_retriever", "multi_query_embedding_retriever", "multi_query_text_retriever",
"sentence_window_retriever"]
"multi_retriever", "text_embedding_retriever", "sentence_window_retriever"]
processors:
- type: filter
documented_only: true
@@ -0,0 +1,14 @@
---
features:
- |
Added two new retriever components: ``MultiRetriever`` and ``TextEmbeddingRetriever``.
``MultiRetriever`` is marked as experimental and may change or be removed in future releases without prior deprecation notice.
An ``ExperimentalWarning`` is printed when initializing this component.
``MultiRetriever`` combines multiple text retrievers into a single component.
All text retrievers are queried in parallel and their results are deduplicated before being returned.
Use the ``active_retrievers`` parameter to enable or disable specific retrievers at runtime.
``TextEmbeddingRetriever`` wraps an embedding-based retriever together with a text embedder into a single
component that implements the ``TextRetriever`` protocol, making it compatible with ``MultiRetriever``.
@@ -0,0 +1,455 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any
from unittest.mock import ANY
import pytest
from haystack import Document, component
from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
from haystack.components.retrievers import (
InMemoryBM25Retriever,
InMemoryEmbeddingRetriever,
MultiRetriever,
TextEmbeddingRetriever,
)
from haystack.components.writers import DocumentWriter
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.types import DuplicatePolicy
from haystack.utils.experimental import ExperimentalWarning
pytestmark = pytest.mark.filterwarnings("ignore::haystack.utils.experimental.ExperimentalWarning")
@component
class MockRetriever:
def __init__(self, documents: list[Document] | None = None):
self.documents = documents or []
@component.output_types(documents=list[Document])
def run(self, query: str, filters: dict[str, Any] | None = None, top_k: int | None = None):
return {"documents": self.documents}
@component
class FailingRetriever:
@component.output_types(documents=list[Document])
def run(self, query: str, filters: dict[str, Any] | None = None, top_k: int | None = None):
raise RuntimeError("connection error")
@pytest.fixture
def sample_documents():
return [
Document(
content="Renewable energy is energy that is collected from renewable resources.",
meta={"category": "renewable"},
id="doc1",
),
Document(
content="Solar energy is a type of green energy that is harnessed from the sun.",
meta={"category": "solar"},
id="doc2",
),
Document(
content="Wind energy is another type of green energy that is generated by wind turbines.",
meta={"category": "wind"},
id="doc3",
),
Document(
content="Geothermal energy is heat that comes from the sub-surface of the earth.",
meta={"category": "geothermal"},
id="doc4",
),
Document(
content="Fossil fuels, such as coal, oil, and natural gas, are non-renewable energy sources.",
meta={"category": "fossil"},
id="doc5",
),
]
@pytest.fixture
def document_store_with_embeddings(sample_documents):
"""Create a document store populated with embedded documents."""
document_store = InMemoryDocumentStore()
doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
doc_writer = DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP)
embedded_docs = doc_embedder.run(sample_documents)["documents"]
doc_writer.run(documents=embedded_docs)
return document_store
@pytest.fixture
def bm25_retriever(document_store_with_embeddings):
return InMemoryBM25Retriever(document_store=document_store_with_embeddings)
@pytest.fixture
def embedding_retriever(document_store_with_embeddings):
return TextEmbeddingRetriever(
retriever=InMemoryEmbeddingRetriever(document_store=document_store_with_embeddings),
text_embedder=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
)
class TestMultiRetriever:
def test_init_default_parameters(self):
retrievers = {"mock": MockRetriever()}
retriever = MultiRetriever(retrievers=retrievers)
assert retriever.retrievers == retrievers
assert retriever.filters is None
assert retriever.top_k == 10
assert retriever.max_workers == 4
def test_init_custom_parameters(self):
retrievers = {"mock": MockRetriever()}
retriever = MultiRetriever(retrievers=retrievers, filters={"field": "meta.category"}, top_k=5, max_workers=2)
assert retriever.retrievers == retrievers
assert retriever.filters == {"field": "meta.category"}
assert retriever.top_k == 5
assert retriever.max_workers == 2
def test_run_with_empty_document_store(self):
retriever = MultiRetriever(retrievers={"mock": MockRetriever()})
result = retriever.run(query="green energy")
assert "documents" in result
assert result["documents"] == []
def test_run_combines_results_from_multiple_retrievers(self, sample_documents):
retriever = MultiRetriever(
retrievers={
"a": MockRetriever(documents=[sample_documents[0]]),
"b": MockRetriever(documents=[sample_documents[1]]),
},
max_workers=2,
)
result = retriever.run(query="energy")
assert len(result["documents"]) == 2
assert {doc.id for doc in result["documents"]} == {"doc1", "doc2"}
def test_run_deduplicates_results(self, sample_documents):
retriever = MultiRetriever(
retrievers={
"c": MockRetriever(documents=[sample_documents[0], sample_documents[1]]),
"d": MockRetriever(documents=[sample_documents[0]]),
},
max_workers=2,
)
result = retriever.run(query="energy")
assert len(result["documents"]) == 2
ids = [doc.id for doc in result["documents"]]
assert ids.count("doc1") == 1
def test_run_resolves_filters_and_top_k(self):
received: dict = {}
@component
class CapturingRetriever:
@component.output_types(documents=list[Document])
def run(self, query: str, filters: dict[str, Any] | None = None, top_k: int | None = None):
received["filters"] = filters
received["top_k"] = top_k
return {"documents": []}
retriever = MultiRetriever(
retrievers={"capturing": CapturingRetriever()}, filters={"field": "meta.category"}, top_k=5
)
# Should use init-time values when not overridden
retriever.run(query="energy")
assert received["filters"] == {"field": "meta.category"}
assert received["top_k"] == 5
# Should prefer run-time values when provided
retriever.run(query="energy", filters={"field": "meta.other"}, top_k=2)
assert received["filters"] == {"field": "meta.other"}
assert received["top_k"] == 2
def test_run_with_active_retrievers(self, sample_documents):
retriever = MultiRetriever(
retrievers={"a": MockRetriever([sample_documents[0]]), "b": MockRetriever([sample_documents[1]])}
)
# Only run retriever "a"
result = retriever.run(query="energy", active_retrievers=["a"])
assert len(result["documents"]) == 1
assert result["documents"][0].id == "doc1"
def test_run_with_unknown_active_retriever_raises(self):
retriever = MultiRetriever(retrievers={"mock": MockRetriever()})
with pytest.raises(ValueError, match="Unknown retriever name"):
retriever.run(query="energy", active_retrievers=["nonexistent"])
def test_run_retriever_failure_raises_with_name(self):
retriever = MultiRetriever(retrievers={"failing": FailingRetriever()})
with pytest.raises(RuntimeError, match="Retriever 'failing' failed"):
retriever.run(query="energy")
def test_to_dict(self):
retriever = MultiRetriever(
retrievers={"bm25": InMemoryBM25Retriever(document_store=InMemoryDocumentStore())},
filters=None,
top_k=5,
max_workers=2,
)
result = retriever.to_dict()
assert result == {
"type": "haystack.components.retrievers.multi_retriever.MultiRetriever",
"init_parameters": {
"retrievers": {
"bm25": {
"type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever",
"init_parameters": {
"document_store": {
"type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
"init_parameters": {
"bm25_tokenization_regex": "(?u)\\b\\w+\\b",
"bm25_algorithm": "BM25L",
"bm25_parameters": {},
"embedding_similarity_function": "dot_product",
"index": ANY,
"return_embedding": True,
},
},
"filters": None,
"top_k": 10,
"scale_score": False,
"filter_policy": "replace",
},
}
},
"filters": None,
"top_k": 5,
"max_workers": 2,
},
}
def test_from_dict(self):
data = {
"type": "haystack.components.retrievers.multi_retriever.MultiRetriever",
"init_parameters": {
"retrievers": {
"bm25": {
"type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever",
"init_parameters": {
"document_store": {
"type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
"init_parameters": {
"bm25_tokenization_regex": "(?u)\\b\\w\\w+\\b",
"bm25_algorithm": "BM25L",
"bm25_parameters": {},
"embedding_similarity_function": "dot_product",
"index": "4bb5369d-779f-487b-9c16-3c40f503438b",
"return_embedding": True,
},
},
"filters": None,
"top_k": 10,
"scale_score": False,
"filter_policy": "replace",
},
}
},
"filters": None,
"top_k": 5,
"max_workers": 2,
},
}
result = MultiRetriever.from_dict(data)
assert isinstance(result, MultiRetriever)
assert len(result.retrievers) == 1
assert "bm25" in result.retrievers
assert isinstance(result.retrievers["bm25"], InMemoryBM25Retriever)
assert result.top_k == 5
assert result.max_workers == 2
def test_from_dict_with_no_retrievers(self):
data = {
"type": "haystack.components.retrievers.multi_retriever.MultiRetriever",
"init_parameters": {"retrievers": {}, "filters": None, "top_k": 10, "max_workers": 4},
}
result = MultiRetriever.from_dict(data)
assert isinstance(result, MultiRetriever)
assert result.retrievers == {}
def test_from_dict_with_unknown_retriever_type_raises(self):
data = {
"type": "haystack.components.retrievers.multi_retriever.MultiRetriever",
"init_parameters": {
"retrievers": {
"bad": {"type": "haystack.components.retrievers.NonExistentRetriever", "init_parameters": {}}
},
"filters": None,
"top_k": 10,
"max_workers": 4,
},
}
with pytest.raises(ImportError, match="Could not import class"):
MultiRetriever.from_dict(data)
@pytest.mark.integration
@pytest.mark.slow
def test_run_with_filters(self, del_hf_env_vars, bm25_retriever, embedding_retriever):
retriever = MultiRetriever(retrievers={"bm25": bm25_retriever, "embedding": embedding_retriever})
result = retriever.run(query="energy", filters={"field": "meta.category", "operator": "==", "value": "solar"})
assert len(result["documents"]) == 1
assert result["documents"][0].meta["category"] == "solar"
@pytest.mark.integration
@pytest.mark.slow
def test_run_with_top_k(self, del_hf_env_vars, bm25_retriever, embedding_retriever):
retriever = MultiRetriever(retrievers={"bm25": bm25_retriever, "embedding": embedding_retriever})
result = retriever.run(query="energy", top_k=2)
assert len(result["documents"]) == 2
@pytest.mark.integration
@pytest.mark.slow
def test_run_with_active_retrievers_integration(self, del_hf_env_vars, bm25_retriever, embedding_retriever):
retriever = MultiRetriever(retrievers={"bm25": bm25_retriever, "embedding": embedding_retriever})
result_bm25_active = retriever.run(query="energy", active_retrievers=["bm25"])
result_bm25 = bm25_retriever.run(query="energy")
assert result_bm25_active == result_bm25
class TestMultiRetrieverAsync:
@pytest.mark.asyncio
async def test_run_async_with_empty_results(self):
retriever = MultiRetriever(retrievers={"mock": MockRetriever()})
result = await retriever.run_async(query="green energy")
assert "documents" in result
assert result["documents"] == []
@pytest.mark.asyncio
async def test_run_async_combines_results_from_multiple_retrievers(self, sample_documents):
retriever = MultiRetriever(
retrievers={
"a": MockRetriever(documents=[sample_documents[0]]),
"b": MockRetriever(documents=[sample_documents[1]]),
}
)
result = await retriever.run_async(query="energy")
assert len(result["documents"]) == 2
assert {doc.id for doc in result["documents"]} == {"doc1", "doc2"}
@pytest.mark.asyncio
async def test_run_async_deduplicates_results(self, sample_documents):
retriever = MultiRetriever(
retrievers={
"c": MockRetriever(documents=[sample_documents[0], sample_documents[1]]),
"d": MockRetriever(documents=[sample_documents[0]]),
}
)
result = await retriever.run_async(query="energy")
assert len(result["documents"]) == 2
assert [doc.id for doc in result["documents"]].count("doc1") == 1
@pytest.mark.asyncio
async def test_run_async_resolves_filters_and_top_k(self):
received: dict = {}
@component
class CapturingRetriever:
@component.output_types(documents=list[Document])
def run(self, query: str, filters: dict[str, Any] | None = None, top_k: int | None = None):
received["filters"] = filters
received["top_k"] = top_k
return {"documents": []}
retriever = MultiRetriever(
retrievers={"capturing": CapturingRetriever()}, filters={"field": "meta.category"}, top_k=5
)
await retriever.run_async(query="energy")
assert received["filters"] == {"field": "meta.category"}
assert received["top_k"] == 5
await retriever.run_async(query="energy", filters={"field": "meta.other"}, top_k=2)
assert received["filters"] == {"field": "meta.other"}
assert received["top_k"] == 2
@pytest.mark.asyncio
async def test_run_async_with_active_retrievers(self, sample_documents):
retriever = MultiRetriever(
retrievers={"a": MockRetriever([sample_documents[0]]), "b": MockRetriever([sample_documents[1]])}
)
result = await retriever.run_async(query="energy", active_retrievers=["a"])
assert len(result["documents"]) == 1
assert result["documents"][0].id == "doc1"
@pytest.mark.asyncio
async def test_run_async_with_unknown_active_retriever_raises(self):
retriever = MultiRetriever(retrievers={"mock": MockRetriever()})
with pytest.raises(ValueError, match="Unknown retriever name"):
await retriever.run_async(query="energy", active_retrievers=["nonexistent"])
@pytest.mark.asyncio
async def test_run_async_retriever_failure_raises_with_name(self):
retriever = MultiRetriever(retrievers={"failing": FailingRetriever()})
with pytest.raises(RuntimeError, match="Retriever 'failing' failed"):
await retriever.run_async(query="energy")
@pytest.mark.asyncio
async def test_run_async_uses_run_async_on_retriever_if_available(self):
@component
class AsyncCapableRetriever:
def __init__(self):
self.used_async = False
@component.output_types(documents=list[Document])
def run(self, query: str, filters: dict[str, Any] | None = None, top_k: int | None = None):
return {"documents": []}
@component.output_types(documents=list[Document])
async def run_async(self, query: str, filters: dict[str, Any] | None = None, top_k: int | None = None):
self.used_async = True
return {"documents": [Document(content="async result", id="async1")]}
inner = AsyncCapableRetriever()
retriever = MultiRetriever(retrievers={"async_capable": inner})
result = await retriever.run_async(query="energy")
assert inner.used_async is True
assert len(result["documents"]) == 1
assert result["documents"][0].id == "async1"
@pytest.mark.integration
@pytest.mark.slow
@pytest.mark.asyncio
async def test_run_async_with_filters(self, del_hf_env_vars, bm25_retriever, embedding_retriever):
retriever = MultiRetriever(retrievers={"bm25": bm25_retriever, "embedding": embedding_retriever})
result = await retriever.run_async(
query="energy", filters={"field": "meta.category", "operator": "==", "value": "solar"}
)
assert len(result["documents"]) == 1
assert result["documents"][0].meta["category"] == "solar"
@pytest.mark.integration
@pytest.mark.slow
@pytest.mark.asyncio
async def test_run_async_with_top_k(self, del_hf_env_vars, bm25_retriever, embedding_retriever):
retriever = MultiRetriever(retrievers={"bm25": bm25_retriever, "embedding": embedding_retriever})
result = await retriever.run_async(query="energy", top_k=2)
assert len(result["documents"]) == 2
@pytest.mark.integration
@pytest.mark.slow
@pytest.mark.asyncio
async def test_run_async_with_active_retrievers_integration(
self, del_hf_env_vars, bm25_retriever, embedding_retriever
):
retriever = MultiRetriever(retrievers={"bm25": bm25_retriever, "embedding": embedding_retriever})
result_bm25_active = await retriever.run_async(query="energy", active_retrievers=["bm25"])
result_bm25 = await bm25_retriever.run_async(query="energy")
assert result_bm25_active == result_bm25
class TestMultiRetrieverExperimental:
@pytest.mark.filterwarnings("always::haystack.utils.experimental.ExperimentalWarning")
def test_emits_experimental_warning_on_init(self):
with pytest.warns(ExperimentalWarning, match="MultiRetriever.*experimental"):
MultiRetriever(retrievers={"mock": MockRetriever()})
@pytest.mark.filterwarnings("always::haystack.utils.experimental.ExperimentalWarning")
def test_experimental_attribute_is_set(self):
assert getattr(MultiRetriever, "__experimental__", False) is True
@@ -0,0 +1,193 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any
from unittest.mock import ANY
import numpy as np
import pytest
from haystack import Document, component
from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
from haystack.components.retrievers import InMemoryEmbeddingRetriever, TextEmbeddingRetriever
from haystack.components.writers import DocumentWriter
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.types import DuplicatePolicy
@component
class MockTextEmbedder:
@component.output_types(embedding=list[float])
def run(self, text: str) -> dict[str, list[float]]:
return {"embedding": np.ones(384).tolist()}
class TestTextEmbeddingRetriever:
@pytest.fixture
def sample_documents(self):
return [
Document(content="Renewable energy is energy that is collected from renewable resources."),
Document(content="Solar energy is a type of green energy that is harnessed from the sun."),
Document(content="Wind energy is another type of green energy that is generated by wind turbines."),
Document(content="Geothermal energy is heat that comes from the sub-surface of the earth."),
Document(content="Fossil fuels, such as coal, oil, and natural gas, are non-renewable energy sources."),
]
@pytest.fixture
def document_store_with_embeddings(self, sample_documents):
"""Create a document store populated with embedded documents."""
document_store = InMemoryDocumentStore()
doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
doc_writer = DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP)
embedded_docs = doc_embedder.run(sample_documents)["documents"]
doc_writer.run(documents=embedded_docs)
return document_store
def test_init(self):
embedding_retriever = InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore())
text_embedder = MockTextEmbedder()
retriever = TextEmbeddingRetriever(retriever=embedding_retriever, text_embedder=text_embedder)
assert retriever.retriever == embedding_retriever
assert retriever.text_embedder == text_embedder
def test_run_with_empty_document_store(self):
retriever = TextEmbeddingRetriever(
retriever=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()),
text_embedder=MockTextEmbedder(),
)
result = retriever.run(query="green energy")
assert "documents" in result
assert result["documents"] == []
def test_run_returns_documents_sorted_by_score(self):
doc_high = Document(content="Solar energy", id="doc1", score=0.9)
doc_low = Document(content="Fossil fuels", id="doc2", score=0.3)
doc_mid = Document(content="Wind energy", id="doc3", score=0.6)
@component
class MockRetriever:
@component.output_types(documents=list[Document])
def run(
self, query_embedding: list[float], filters: dict[str, Any] | None = None, top_k: int | None = None
) -> dict[str, list[Document]]:
return {"documents": [doc_low, doc_high, doc_mid]}
retriever = TextEmbeddingRetriever(retriever=MockRetriever(), text_embedder=MockTextEmbedder())
result = retriever.run(query="energy")
scores = [doc.score for doc in result["documents"]]
assert scores == sorted(scores, reverse=True)
def test_to_dict(self):
retriever = TextEmbeddingRetriever(
retriever=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()),
text_embedder=MockTextEmbedder(),
)
result = retriever.to_dict()
assert result == {
"type": "haystack.components.retrievers.text_embedding_retriever.TextEmbeddingRetriever",
"init_parameters": {
"retriever": {
"type": "haystack.components.retrievers.in_memory.embedding_retriever.InMemoryEmbeddingRetriever",
"init_parameters": {
"document_store": {
"type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
"init_parameters": {
"bm25_tokenization_regex": "(?u)\\b\\w+\\b",
"bm25_algorithm": "BM25L",
"bm25_parameters": {},
"embedding_similarity_function": "dot_product",
"index": ANY,
"return_embedding": True,
},
},
"filters": None,
"top_k": 10,
"scale_score": False,
"return_embedding": False,
"filter_policy": "replace",
},
},
"text_embedder": {
"type": "retrievers.test_text_embedding_retriever.MockTextEmbedder",
"init_parameters": {},
},
},
}
def test_from_dict(self):
data = {
"type": "haystack.components.retrievers.text_embedding_retriever.TextEmbeddingRetriever",
"init_parameters": {
"retriever": {
"type": "haystack.components.retrievers.in_memory.embedding_retriever.InMemoryEmbeddingRetriever",
"init_parameters": {
"document_store": {
"type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
"init_parameters": {
"bm25_tokenization_regex": "(?u)\\b\\w\\w+\\b",
"bm25_algorithm": "BM25L",
"bm25_parameters": {},
"embedding_similarity_function": "dot_product",
"index": "4bb5369d-779f-487b-9c16-3c40f503438b",
"return_embedding": True,
},
},
"filters": None,
"top_k": 10,
"scale_score": False,
"return_embedding": False,
"filter_policy": "replace",
},
},
"text_embedder": {
"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder", # noqa E501
"init_parameters": {
"model": "sentence-transformers/all-MiniLM-L6-v2",
"token": {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False},
"prefix": "",
"suffix": "",
"batch_size": 32,
"progress_bar": True,
"normalize_embeddings": False,
"trust_remote_code": False,
"local_files_only": False,
"truncate_dim": None,
"model_kwargs": None,
"tokenizer_kwargs": None,
"config_kwargs": None,
"precision": "float32",
"encode_kwargs": None,
"backend": "torch",
},
},
},
}
result = TextEmbeddingRetriever.from_dict(data)
assert isinstance(result, TextEmbeddingRetriever)
assert isinstance(result.retriever, InMemoryEmbeddingRetriever)
assert isinstance(result.text_embedder, SentenceTransformersTextEmbedder)
@pytest.mark.integration
@pytest.mark.slow
def test_run_with_filters(self, del_hf_env_vars, document_store_with_embeddings):
retriever = TextEmbeddingRetriever(
retriever=InMemoryEmbeddingRetriever(document_store=document_store_with_embeddings),
text_embedder=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
)
result = retriever.run(query="energy", filters={"field": "meta.category", "operator": "==", "value": "solar"})
assert "documents" in result
assert all(doc.meta.get("category") == "solar" for doc in result["documents"])
@pytest.mark.integration
@pytest.mark.slow
def test_run_with_top_k(self, del_hf_env_vars, document_store_with_embeddings):
retriever = TextEmbeddingRetriever(
retriever=InMemoryEmbeddingRetriever(document_store=document_store_with_embeddings),
text_embedder=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
)
result = retriever.run(query="energy", top_k=2)
assert "documents" in result
assert len(result["documents"]) <= 2
+87
View File
@@ -0,0 +1,87 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import pytest
from haystack import component
from haystack.utils.experimental import ExperimentalWarning, _experimental
class TestExperimentalDecorator:
def test_emits_experimental_warning_on_init(self):
@_experimental
@component
class MyComponent:
@component.output_types(value=int)
def run(self, value: int) -> dict:
return {"value": value}
with pytest.warns(ExperimentalWarning):
MyComponent()
def test_warning_message_contains_class_name(self):
@_experimental
@component
class MyComponent:
@component.output_types(value=int)
def run(self, value: int) -> dict:
return {"value": value}
with pytest.warns(ExperimentalWarning, match="MyComponent"):
MyComponent()
def test_sets_experimental_attribute(self):
@_experimental
@component
class MyComponent:
@component.output_types(value=int)
def run(self, value: int) -> dict:
return {"value": value}
assert MyComponent.__experimental__ is True
def test_passes_args_and_kwargs_to_init(self):
@_experimental
@component
class MyComponent:
def __init__(self, value: int, label: str = "default"):
self.value = value
self.label = label
@component.output_types(value=int)
def run(self, value: int) -> dict:
return {"value": value}
with pytest.warns(ExperimentalWarning):
instance = MyComponent(42, label="custom")
assert instance.value == 42
assert instance.label == "custom"
def test_preserves_init_name(self):
@_experimental
@component
class MyComponent:
@component.output_types(value=int)
def run(self, value: int) -> dict:
return {"value": value}
assert MyComponent.__init__.__name__ == "__init__"
def test_experimental_warning_is_user_warning_subclass(self):
assert issubclass(ExperimentalWarning, UserWarning)
def test_warning_emitted_on_every_instantiation(self):
@_experimental
@component
class MyComponent:
@component.output_types(value=int)
def run(self, value: int) -> dict:
return {"value": value}
with pytest.warns(ExperimentalWarning):
MyComponent()
with pytest.warns(ExperimentalWarning):
MyComponent()