"""Embedder module for generating document embeddings.
This module provides the Embedder class for generating embeddings from text chunks
using sentence-transformers models with batch processing support.
"""
import logging
import os
from typing import Any
from sentence_transformers import SentenceTransformer
from thoth.shared.utils.logger import setup_logger
# Module-level default logger; an Embedder falls back to it when no
# logger_instance is passed to its constructor.
logger = setup_logger(__name__)
[docs]
class Embedder:
"""Generate embeddings from text using sentence-transformers.
Supports batch processing with progress tracking for efficient embedding generation.
"""
[docs]
def __init__(
self,
model_name: str = "all-MiniLM-L6-v2",
device: str | None = None,
batch_size: int = 32,
logger_instance: logging.Logger | logging.LoggerAdapter | None = None,
):
"""Initialize the Embedder with a sentence-transformers model.
Args:
model_name: Name of the sentence-transformers model to use.
Default is 'all-MiniLM-L6-v2' for a good balance of speed and quality.
Other options: 'all-mpnet-base-v2' (higher quality, slower).
device: Device to use for inference ('cuda', 'cpu', or None for auto-detect).
batch_size: Number of texts to process in each batch (default: 32).
logger_instance: Optional logger instance to use.
"""
self.model_name = model_name
self.batch_size = batch_size
self.logger = logger_instance or logger
# HuggingFace token: required for gated models; sentence-transformers reads HUGGING_FACE_HUB_TOKEN.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_token
self.logger.info(f"Loading embedding model: {model_name}")
# First load downloads from HuggingFace Hub if not cached; device=None auto-selects CUDA/CPU.
self.model = SentenceTransformer(model_name, device=device)
self.logger.info(f"Model loaded successfully on device: {self.model.device}")
[docs]
def embed(
self,
texts: list[str],
show_progress: bool = False,
normalize: bool = True,
) -> list[list[float]]:
"""Generate embeddings for a list of texts.
Args:
texts: List of text strings to embed.
show_progress: Whether to show a progress bar during batch processing.
normalize: Whether to normalize embeddings to unit length (default: True).
Normalized embeddings work better with cosine similarity.
Returns:
List of embedding vectors, where each vector is a list of floats.
Raises:
ValueError: If texts list is empty or contains empty/whitespace-only strings.
"""
if not texts:
msg = "Cannot generate embeddings for empty text list"
raise ValueError(msg)
invalid_indices = [i for i, text in enumerate(texts) if not isinstance(text, str) or not text.strip()]
if invalid_indices:
msg = (
"Cannot generate embeddings for empty or whitespace-only texts; "
f"invalid entries at indices: {invalid_indices}"
)
raise ValueError(msg)
self.logger.info(f"Generating embeddings for {len(texts)} texts with batch_size={self.batch_size}")
# Generate embeddings with batch processing
embeddings = self.model.encode(
texts,
batch_size=self.batch_size,
show_progress_bar=show_progress,
normalize_embeddings=normalize,
convert_to_numpy=True,
)
# Convert numpy arrays to lists for JSON serialization
embeddings_list: list[list[float]] = embeddings.tolist()
self.logger.info(f"Generated {len(embeddings_list)} embeddings of dimension {len(embeddings_list[0])}")
return embeddings_list
[docs]
def embed_single(self, text: str, normalize: bool = True) -> list[float]:
"""Generate embedding for a single text.
Args:
text: Text string to embed.
normalize: Whether to normalize embedding to unit length (default: True).
Returns:
Embedding vector as a list of floats.
Raises:
ValueError: If text is empty.
"""
if not text:
msg = "Cannot generate embedding for empty text"
raise ValueError(msg)
embeddings = self.embed([text], show_progress=False, normalize=normalize)
return embeddings[0]
[docs]
def get_embedding_dimension(self) -> int:
"""Get the dimension of embeddings produced by this model.
Returns:
Integer dimension of the embedding vectors.
"""
return self.model.get_sentence_embedding_dimension() # type: ignore[return-value]
[docs]
def get_model_info(self) -> dict[str, Any]:
"""Get information about the loaded model.
Returns:
Dictionary containing model metadata:
- model_name: Name of the model
- embedding_dimension: Dimension of embeddings
- max_seq_length: Maximum sequence length the model can handle
- device: Device the model is running on
- batch_size: Configured batch size for processing
"""
return {
"model_name": self.model_name,
"embedding_dimension": self.get_embedding_dimension(),
"max_seq_length": self.model.max_seq_length,
"device": str(self.model.device),
"batch_size": self.batch_size,
}