Source code for rllm.preprocessing.word_embedding
from dataclasses import dataclass
from typing import Callable, Optional
from pandas import Series
import torch
from torch import Tensor
from tqdm import tqdm
[docs]
@dataclass
class TextEmbedderConfig:
    """Settings that control how raw text is turned into embeddings.

    Used by preprocessing pipelines to bundle the embedding function with
    an optional chunk size applied during inference.

    Attributes:
        text_embedder (Callable[[list[str]], Tensor]): Function that takes a
            list of strings and returns their embeddings as a tensor.
        batch_size (Optional[int]): Number of strings handed to
            ``text_embedder`` per call. ``None`` (the default) means the
            whole input is passed in a single call.
    """

    # Embedding function: list of strings in, tensor of embeddings out.
    text_embedder: Callable[[list[str]], Tensor]
    # Chunk size for mini-batched embedding; ``None`` disables batching.
    batch_size: Optional[int] = None
[docs]
def embed_text_column(
    col_series: Series,
    config: TextEmbedderConfig,
) -> Tensor:
    r"""Embed a text column into dense vector representations.

    Every value is first converted with ``Series.astype(str)``, so missing
    values are embedded as their string form (e.g. ``"nan"`` / ``"None"``).
    If ``config.batch_size`` is ``None`` the embedder is called once on the
    whole column; otherwise it is called once per mini-batch and the outputs
    are concatenated along dim 0.

    Args:
        col_series (Series): Input text column.
        config (TextEmbedderConfig): Embedding configuration.

    Returns:
        Tensor: Embedded features with shape :math:`(N, D)` and dtype
        ``torch.float32`` (assuming ``text_embedder`` returns a 2-D tensor
        per call).

    Raises:
        ValueError: If ``config.text_embedder`` is ``None`` or
            ``config.batch_size`` is not ``None`` and smaller than 1.
    """
    embedder = config.text_embedder
    batch_size = config.batch_size
    # Explicit checks rather than ``assert``: asserts are stripped under -O.
    if embedder is None:
        raise ValueError("Need an embedder for text column!")
    if batch_size is not None and batch_size < 1:
        # range() would raise on 0 and yield no batches for negatives.
        raise ValueError(
            f"batch_size must be a positive integer or None, got {batch_size!r}."
        )

    texts = col_series.astype(str).to_list()

    # An empty column yields no mini-batches and ``torch.cat`` rejects an
    # empty list, so use the single-call path for it as well.
    if batch_size is None or not texts:
        embeddings = embedder(texts)
    else:
        chunks = [
            embedder(texts[start : start + batch_size])
            for start in tqdm(
                range(0, len(texts), batch_size),
                desc="Embedding raw data in mini-batch",
            )
        ]
        embeddings = torch.cat(chunks, dim=0)
    return embeddings.float()