Commit f88e6a2

feat(dense): prefer OpenAI embeddings; split dense extras; tighten indexing tests
Why:
- sentence-transformers pulls torch + CUDA wheels on many platforms; this makes the default "dense" extra hard to install and test in constrained environments.

What:
- Split extras: "dense" = FAISS only; "dense-st" = SentenceTransformers.
- Prefer OpenAI embeddings when OPENAI_API_KEY is configured; fall back to SentenceTransformers when installed (see the sketch below).
- Align bootstrap/build_index and the API router on the same embedder-selection behavior.
- Allow HistorySqlStorage to accept a session_factory for testability.
- Replace the SentenceTransformers integration test with an OpenAIEmbedder integration test that stubs the SDK.
- Add an end-to-end dense + hybrid test that works with the FAISS numpy fallback.
- Strengthen the build_index sparse integration assertions.

Notes:
- uv.lock was updated to reflect the new extras; regeneration requires network access.
1 parent 867c3df commit f88e6a2

13 files changed: +236 −64 lines changed
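The embedder-selection behavior named above is repeated at every touched call site. A minimal sketch of the pattern, for orientation only (select_embedder is a hypothetical helper used here for illustration; the commit inlines the equivalent expression in api_router, bootstrap, and build_index):

from local_rag_backend.core.ports import EmbedderPort
from local_rag_backend.infrastructure.embeddings.openai import OpenAIEmbedder
from local_rag_backend.infrastructure.embeddings.sentence_transformers import (
    SentenceTransformerEmbedder,
)
from local_rag_backend.settings import settings


def select_embedder() -> EmbedderPort:
    # Prefer the lightweight OpenAI client whenever a key is configured;
    # otherwise fall back to the optional sentence-transformers stack ("dense-st" extra).
    if settings.openai_api_key:
        return OpenAIEmbedder()
    return SentenceTransformerEmbedder(model_name=settings.st_embedding_model)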

pyproject.toml

Lines changed: 7 additions & 5 deletions

@@ -66,9 +66,12 @@ dependencies = [
 
 [project.optional-dependencies]
 dense = [
-    "sentence-transformers>=2.7,<3.0",
     "faiss-cpu>=1.8.0,<2.0"
 ]
+dense-st = [
+    # Heavy dependency chain (torch/transformers). Prefer OpenAI embeddings when possible.
+    "sentence-transformers>=2.7,<3.0"
+]
 dev = [
     # Testing
     "pytest>=8.2,<9.0",
@@ -130,8 +133,8 @@ monitoring = [
     "sentry-sdk[fastapi]>=1.40,<2.0"
 ]
 all = [
-    "sentence-transformers>=2.7,<3.0",
     "faiss-cpu>=1.8.0,<2.0",
+    "sentence-transformers>=2.7,<3.0",
     "pytest>=8.2,<9.0",
     "pytest-cov>=6.1.1,<7.0",
     "pytest-asyncio>=0.23,<1.0",
@@ -253,9 +256,8 @@ exclude = "(?x)^(tests/|scripts/)"
 
 plugins = ["pydantic.mypy"]
 
-# Pydantic 2 plugin tuning
-[mypy.plugins.dummy]  # dummy section so the TOML stays valid; the plugin reads its config via [tool.mypy]
-# (plugin options are still passed under [tool.mypy])
+# Pydantic plugin tuning (configured in pyproject)
+[tool.pydantic-mypy]
 init_forbid_extra = true
 init_typed = true
 warn_required_dynamic_aliases = true

src/local_rag_backend/app/api_router.py

Lines changed: 13 additions & 3 deletions

@@ -19,6 +19,7 @@
 from local_rag_backend.app.dependencies import get_rag_service
 from local_rag_backend.core.services.etl import ETLService
 from local_rag_backend.core.services.rag import RagService
+from local_rag_backend.infrastructure.embeddings.openai import OpenAIEmbedder
 from local_rag_backend.infrastructure.embeddings.sentence_transformers import (
     SentenceTransformerEmbedder,
 )
@@ -52,7 +53,12 @@
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
 
-    from local_rag_backend.core.ports import DocumentRepoPort, GeneratorPort, RetrieverPort
+    from local_rag_backend.core.ports import (
+        DocumentRepoPort,
+        EmbedderPort,
+        GeneratorPort,
+        RetrieverPort,
+    )
 
 # ---------------------- Validation Utilities ---------------------- #
 
@@ -253,7 +259,9 @@ async def ingest_docs(payload: Annotated[IngestRequest, Body(...)]) -> IngestRes
     doc_repo = SqlDocumentStorage()
 
     if settings.retrieval_mode in ("dense", "hybrid"):
-        embedder = SentenceTransformerEmbedder(model_name=settings.st_embedding_model)
+        embedder: EmbedderPort = OpenAIEmbedder() if settings.openai_api_key else SentenceTransformerEmbedder(
+            model_name=settings.st_embedding_model
+        )
         vec = FaissVectorStorage(
             index_path=settings.index_path,
             id_map_path=settings.id_map_path,
@@ -277,7 +285,9 @@ def _build_retriever_from_config(
     if cfg.retrieval_mode == "sparse":
         return SparseBM25Retriever(documents=corpus, doc_ids=doc_ids, doc_repo=doc_repo)
 
-    embedder = SentenceTransformerEmbedder(model_name=settings.st_embedding_model)
+    embedder: EmbedderPort = OpenAIEmbedder() if settings.openai_api_key else SentenceTransformerEmbedder(
+        model_name=settings.st_embedding_model
+    )
     faiss_storage = FaissVectorStorage(
         index_path=settings.index_path, id_map_path=settings.id_map_path, dim=embedder.dim
     )

src/local_rag_backend/infrastructure/embeddings/openai.py

Lines changed: 6 additions & 5 deletions

@@ -5,14 +5,11 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from collections.abc import Sequence
+from collections.abc import Sequence
 
 from openai import OpenAI
 
-from local_rag_backend.core.ports import EmbedderPort, Embedding
+from local_rag_backend.core.ports import EmbedderPort
 from local_rag_backend.settings import settings
 
 _MODEL_DIM: dict[str, int] = {
@@ -24,13 +21,17 @@
 DEFAULT_MODEL = settings.openai_embedding_model
 DEFAULT_DIM = _MODEL_DIM.get(DEFAULT_MODEL, 1536)
 
+Embedding = Sequence[float]
+
 
 class OpenAIEmbedder(EmbedderPort):
     dim: int  # required by the port
 
     def __init__(self, model: str | None = None):
         self.model = model or settings.openai_embedding_model
         self.dim = _MODEL_DIM.get(self.model, DEFAULT_DIM)
+        if not settings.openai_api_key:
+            raise RuntimeError("OPENAI_API_KEY is required to use OpenAI embeddings.")
         self.client = OpenAI(api_key=settings.openai_api_key)
 
     def embed(self, texts: Sequence[str]) -> Sequence[Embedding]:
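The hunk above ends at the embed signature. A plausible body, consistent with the OpenAI embeddings API and with the stubbed client in the new integration test, would look roughly like this (a sketch, not necessarily the file's actual implementation):

    def embed(self, texts: Sequence[str]) -> Sequence[Embedding]:
        # One batched request; the API returns one vector per input, in order.
        resp = self.client.embeddings.create(model=self.model, input=list(texts))
        return [item.embedding for item in resp.data]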

src/local_rag_backend/infrastructure/persistence/sqlalchemy/sql_.py

Lines changed: 4 additions & 1 deletion

@@ -57,7 +57,10 @@ def get_all_documents(self) -> Sequence[DomainDocument]:
 class HistorySqlStorage(QAHistoryPort):
     """SQL-based implementation of the history repository port."""
 
+    def __init__(self, session_factory: sessionmaker[Session] | None = None):
+        self._session_factory = session_factory or SessionLocal
+
     def save(self, q: str, a: str, source_ids: Sequence[int]) -> None:
         """Save a question-answer pair to the history table."""
-        with get_session(SessionLocal) as session:
+        with get_session(self._session_factory) as session:
             add_history(session, q, a, source_ids=list(source_ids))
src/local_rag_backend/scripts/bootstrap.py

Lines changed: 9 additions & 2 deletions

@@ -7,7 +7,7 @@
 
 from importlib import resources
 from pathlib import Path
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 # IMPORTS FOR DYNAMIC DB
 from sqlalchemy import create_engine
@@ -20,6 +20,7 @@
     default_formatter,
     default_preprocess,
 )
+from local_rag_backend.infrastructure.embeddings.openai import OpenAIEmbedder
 from local_rag_backend.infrastructure.embeddings.sentence_transformers import (
     SentenceTransformerEmbedder,
 )
@@ -32,6 +33,9 @@
 
 DELIMITER = ";"
 
+if TYPE_CHECKING:
+    from local_rag_backend.core.ports import EmbedderPort
+
 
 def main(csv_path: str | Path | None = None, **kwargs: Any) -> None:
     if "settings" not in kwargs:
@@ -51,7 +55,10 @@ def main(csv_path: str | Path | None = None, **kwargs: Any) -> None:
     doc_repo = SqlDocumentStorage(session_factory=session_local)
 
     if settings.retrieval_mode in ["dense", "hybrid"]:
-        embedder = SentenceTransformerEmbedder(model_name=settings.st_embedding_model)
+        embedder: EmbedderPort
+        embedder = OpenAIEmbedder() if settings.openai_api_key else SentenceTransformerEmbedder(
+            model_name=settings.st_embedding_model
+        )
         vector_repo = FaissVectorStorage(
             index_path=settings.index_path,
             id_map_path=settings.id_map_path,

src/local_rag_backend/scripts/build_index.py

Lines changed: 10 additions & 3 deletions

@@ -10,12 +10,13 @@
 import logging
 from importlib import resources
 from pathlib import Path
+from typing import TYPE_CHECKING
 
 from sqlalchemy import create_engine
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.pool import StaticPool
 
-# Import the embedder that will be used for indexing if dense mode
+from local_rag_backend.infrastructure.embeddings.openai import OpenAIEmbedder
 from local_rag_backend.infrastructure.embeddings.sentence_transformers import (
     SentenceTransformerEmbedder,
 )
@@ -33,6 +34,9 @@
     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
 )
 
+if TYPE_CHECKING:
+    from local_rag_backend.core.ports import EmbedderPort
+
 
 def main() -> None:
     """
@@ -63,12 +67,15 @@ def main() -> None:
         return  # Exit if tables can't be created
 
     # 3. Embedder (for dense or hybrid mode)
-    embedder_for_indexing = None
+    embedder_for_indexing: EmbedderPort | None = None
     if settings.retrieval_mode in ["dense", "hybrid"]:
         logger.info(
             f"{settings.retrieval_mode.title()} retrieval mode detected. Initializing embedder for indexing."
         )
-        embedder_for_indexing = SentenceTransformerEmbedder(model_name=settings.st_embedding_model)
+        if settings.openai_api_key:
+            embedder_for_indexing = OpenAIEmbedder()
+        else:
+            embedder_for_indexing = SentenceTransformerEmbedder(model_name=settings.st_embedding_model)
 
     # 4. Use ETL logic directly (similar to bootstrap.py)
     try:

Lines changed: 106 additions & 0 deletions

@@ -0,0 +1,106 @@
+import pytest
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+
+from local_rag_backend.core.services.etl import ETLService
+from local_rag_backend.core.services.rag import RagService
+from local_rag_backend.infrastructure.embeddings import openai as openai_embedder_mod
+from local_rag_backend.infrastructure.embeddings.openai import OpenAIEmbedder
+from local_rag_backend.infrastructure.persistence.faiss.faiss_ import FaissVectorStorage
+from local_rag_backend.infrastructure.persistence.sqlalchemy.base import Base
+from local_rag_backend.infrastructure.persistence.sqlalchemy.sql_ import (
+    HistorySqlStorage,
+    SqlDocumentStorage,
+)
+from local_rag_backend.infrastructure.retrieval.dense_faiss import DenseFaissRetriever
+from local_rag_backend.infrastructure.retrieval.hybrid import HybridRetriever
+from local_rag_backend.infrastructure.retrieval.sparse_bm25 import SparseBM25Retriever
+from local_rag_backend.settings import settings
+from local_rag_backend.utils import get_corpus_and_ids
+
+
+class _DummyEmbeddingItem:
+    def __init__(self, embedding):
+        self.embedding = embedding
+
+
+class _DummyEmbeddingsResp:
+    def __init__(self, vectors):
+        self.data = [_DummyEmbeddingItem(v) for v in vectors]
+
+
+class _DummyEmbeddingsAPI:
+    def create(self, model, input):
+        # Tiny deterministic 4D embedding; good enough for FAISS L2.
+        vecs = []
+        for t in input:
+            t = str(t)
+            vecs.append(
+                [
+                    float(len(t)),
+                    float(sum(ord(c) for c in t) % 17),
+                    float(t.count("a")),
+                    float(t.count("z")),
+                ]
+            )
+        return _DummyEmbeddingsResp(vecs)
+
+
+class _DummyOpenAI:
+    def __init__(self, api_key):
+        self.embeddings = _DummyEmbeddingsAPI()
+
+
+class _DummyGen:
+    def __init__(self, *a, **k):
+        pass
+
+    def generate(self, question, contexts):
+        return f"answer:{question}:{len(contexts)}"
+
+
+@pytest.mark.integration
+def test_dense_and_hybrid_end_to_end(tmp_path, monkeypatch):
+    # Settings
+    db_path = tmp_path / "app.db"
+    index_path = tmp_path / "idx.faiss"
+    id_map_path = tmp_path / "id.pkl"
+
+    monkeypatch.setattr(settings, "sqlite_url", f"sqlite:///{db_path}", raising=False)
+    monkeypatch.setattr(settings, "index_path", str(index_path), raising=False)
+    monkeypatch.setattr(settings, "id_map_path", str(id_map_path), raising=False)
+    monkeypatch.setattr(settings, "openai_api_key", "k", raising=False)
+    monkeypatch.setattr(settings, "openai_embedding_model", "dummy-4", raising=False)
+    monkeypatch.setattr(openai_embedder_mod, "_MODEL_DIM", {"dummy-4": 4}, raising=False)
+    monkeypatch.setattr(openai_embedder_mod, "OpenAI", _DummyOpenAI, raising=True)
+
+    # DB setup
+    engine = create_engine(settings.sqlite_url, connect_args={"check_same_thread": False})
+    SessionLocal = sessionmaker(bind=engine, autocommit=False, autoflush=False)
+    Base.metadata.create_all(bind=engine)
+
+    doc_repo = SqlDocumentStorage(session_factory=SessionLocal)
+    history_repo = HistorySqlStorage(session_factory=SessionLocal)
+
+    embedder = OpenAIEmbedder(model="dummy-4")
+    vec_repo = FaissVectorStorage(index_path=str(index_path), id_map_path=str(id_map_path), dim=embedder.dim)
+
+    etl = ETLService(doc_repo, vec_repo, embedder)
+    ids = etl.ingest(["alpha alpha alpha", "zzzz zzzz zzzz"])
+    assert len(list(ids)) == 2
+
+    dense = DenseFaissRetriever(embedder=embedder, faiss_index=vec_repo, doc_repo=doc_repo)
+    docs, scores = dense.retrieve("alpha", k=1)
+    assert len(docs) == 1
+    assert docs[0].content.startswith("alpha")
+    assert len(scores) == 1
+
+    # Hybrid: ensure sparse is wired and returns something too.
+    corpus, doc_ids = get_corpus_and_ids(doc_repo)
+    sparse = SparseBM25Retriever(documents=corpus, doc_ids=doc_ids, doc_repo=doc_repo)
+    hybrid = HybridRetriever(dense=dense, sparse=sparse, alpha=0.5)
+
+    svc = RagService(retriever=hybrid, generator=_DummyGen(), history_storage=history_repo)
+    resp = svc.ask("alpha", top_k=1)
+    assert resp["answer"].startswith("answer:alpha:1")
+    assert resp["docs"]

tests/integration/test_build_index_sparse.py

Lines changed: 5 additions & 7 deletions

@@ -20,19 +20,17 @@ def test_build_index_sparse(tmp_path, monkeypatch):
     monkeypatch.setattr(settings, "faq_csv", str(f), raising=False)
     monkeypatch.setattr(settings, "sqlite_url", f"sqlite:///{tmp_path}/app.db", raising=False)
 
-    # Just test that build_index runs without error and creates database
+    # build_index should run without error and populate the DB.
     from local_rag_backend.scripts import build_index
 
     importlib.reload(build_index)
-    try:
-        build_index.main()
-    except Exception:
-        # If it fails, at least check the database was created
-        pass
+    build_index.main()
 
     # Check database was populated
     eng = create_engine(settings.sqlite_url, connect_args={"check_same_thread": False})
     Session = sessionmaker(bind=eng, autocommit=False, autoflush=False)
     Base.metadata.create_all(bind=eng)
     docs = SqlDocumentStorage(session_factory=Session).get_all_documents()
-    assert len(docs) >= 0  # At least database exists
+    assert len(docs) == 1
+    assert "T" in docs[0].content
+    assert "C" in docs[0].content

tests/integration/test_integration_faiss.py

Lines changed: 0 additions & 4 deletions

@@ -1,7 +1,3 @@
-import pytest
-
-# Skip if faiss is not available in the environment
-pytest.importorskip("faiss")
 
 import numpy as np
 