Source code for string2string.search.faiss_search

"""
This module contains a wrapper for the Faiss library by Facebook AI Research.
"""

from typing import List, Union, Optional, Dict, Any
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
from transformers import AutoTokenizer, AutoModel
from datasets import Dataset

import pandas as pd

# FAISS library wrapper class
[docs]class FaissSearch:
[docs] def __init__(self, model_name_or_path: str = 'facebook/bart-large', tokenizer_name_or_path: str = 'facebook/bart-large', device: str = 'cpu', ) -> None: r""" This function initializes the wrapper for the FAISS library, which is used to perform semantic search. .. attention:: * If you use this class, please make sure to cite the following paper: .. code-block:: latex @article{johnson2019billion, title={Billion-scale similarity search with {GPUs}}, author={Johnson, Jeff and Douze, Matthijs and J{\'e}gou, Herv{\'e}}, journal={IEEE Transactions on Big Data}, volume={7}, number={3}, pages={535--547}, year={2019}, publisher={IEEE} } * The code is based on the following GitHub repository: https://github.com/facebookresearch/faiss Arguments: model_name_or_path (str, optional): The name or path of the model to use. Defaults to 'facebook/bart-large'. tokenizer_name_or_path (str, optional): The name or path of the tokenizer to use. Defaults to 'facebook/bart-large'. device (str, optional): The device to use. Defaults to 'cpu'. Returns: None """ # Set the device self.device = device # If the tokenizer is not specified, use the model name or path if tokenizer_name_or_path is None: tokenizer_name_or_path = model_name_or_path # Load the tokenizer self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) # Load the model self.model = AutoModel.from_pretrained(model_name_or_path).to(self.device) # Set the model to evaluation mode (since we do not need the gradients) self.model.eval() # Initialize the dataset self.dataset = None
# Auxiliary function to get the last hidden state
[docs] def get_last_hidden_state(self, embeddings: torch.Tensor, ) -> torch.Tensor: """ This function returns the last hidden state (e.g., [CLS] token's) of the input embeddings. Arguments: embeddings (torch.Tensor): The input embeddings. Returns: torch.Tensor: The last hidden state. """ # Get the last hidden state last_hidden_state = embeddings.last_hidden_state # Return the last hidden state return last_hidden_state[:, 0, :]
# Auxiliary function to get the mean pooling
[docs] def get_mean_pooling(self, embeddings: torch.Tensor, ) -> torch.Tensor: """ This function returns the mean pooling of the input embeddings. Arguments: embeddings (torch.Tensor): The input embeddings. Returns: torch.Tensor: The mean pooling. """ # Get the mean pooling mean_pooling = embeddings.last_hidden_state.mean(dim=1) # Return the mean pooling return mean_pooling
# Get the embeddings
[docs] def get_embeddings(self, text: Union[str, List[str]], embedding_type: str = 'last_hidden_state', batch_size: int = 8, num_workers: int = 4, ) -> torch.Tensor: """ This function returns the embeddings of the input text. Arguments: text (Union[str, List[str]]): The input text. embedding_type (str, optional): The type of embedding to use. Defaults to 'last_hidden_state'. batch_size (int, optional): The batch size to use. Defaults to 8. num_workers (int, optional): The number of workers to use. Defaults to 4. Returns: torch.Tensor: The embeddings. Raises: ValueError: If the embedding type is invalid. """ # Check if the embedding type is valid if embedding_type not in ['last_hidden_state', 'mean_pooling']: raise ValueError(f'Invalid embedding type: {embedding_type}. Only "last_hidden_state" and "mean_pooling" are supported.') # Tokenize the input text encoded_text = self.tokenizer( text, padding=True, truncation=True, return_tensors='pt', ) # Move the input text to the device encoded_text = encoded_text.to(self.device) # encoded_inputs = {k: v.to(self.device) for k, v in encoded_inputs.items()} # Get the embeddings with torch.no_grad(): embeddings = self.model(**encoded_text) # Get the proper embedding type if embedding_type == 'last_hidden_state': # Get the last hidden state embeddings = self.get_last_hidden_state(embeddings) elif embedding_type == 'mean_pooling': # Get the mean pooling embeddings = self.get_mean_pooling(embeddings) # Return the embeddings return embeddings
# Add FAISS index
[docs] def add_faiss_index(self, column_name: str = 'embeddings', metric_type: Optional[int] = None, batch_size: int = 8, **kwargs, ) -> None: """ This function adds a FAISS index to the dataset. Arguments: column_name (str, optional): The name of the column containing the embeddings. Defaults to 'embeddings'. index_type (str, optional): The index type to use. Defaults to 'Flat'. metric_type (str, optional): The metric type to use. Defaults to 'L2'. Returns: None Raises: ValueError: If the dataset is not initialized. """ # Check if the dataset is initialized if self.dataset is None: raise ValueError('The dataset is not initialized. Please initialize the dataset first.') print('Adding FAISS index...') self.dataset.add_faiss_index( column_name, # metric_type=metric_type, # device=self.device, # batch_size=batch_size, faiss_verbose=True, # **kwargs, )
[docs] def save_faiss_index(self, index_name: str, file_path: str, ) -> None: """ This function saves the FAISS index to the specified file path. * This is a wrapper function for the `save_faiss_index` function in the `Dataset` class. Arguments: index_name (str): The name of the FAISS index (e.g., "embeddings") file_path (str): The file path to save the FAISS index. Returns: None Raises: ValueError: If the dataset is not initialized. """ # Check if the dataset is initialized if self.dataset is None: raise ValueError('The dataset is not initialized. Please initialize the dataset first.') print('Saving FAISS index...') self.dataset.save_faiss_index(index_name=index_name, file=file_path)
[docs] def load_faiss_index(self, index_name: str, file_path: str, device: str = 'cpu', ) -> None: """ This function loads the FAISS index from the specified file path. * This is a wrapper function for the `load_faiss_index` function in the `Dataset` class. Arguments: index_name (str): The name of the FAISS index (e.g., "embeddings") file_path (str): The file path to load the FAISS index from. device (str, optional): The device to use ("cpu" or "cuda") (default: "cpu"). Returns: None Raises: ValueError: If the dataset is not initialized. """ # Check if the dataset is initialized if self.dataset is None: raise ValueError('The dataset is not initialized. Please initialize the dataset first.') print('Loading FAISS index...') self.dataset.load_faiss_index(index_name=index_name, file=file_path, device=device)
# Initialize the corpus using a dictionary or pandas DataFrame or HuggingFace Datasets object
[docs] def initialize_corpus(self, corpus: Union[Dict[str, List[str]], pd.DataFrame, Dataset], section: str = 'text', index_column_name: str = 'embeddings', embedding_type: str = 'last_hidden_state', batch_size: Optional[int] = None, num_workers: Optional[int] = None, save_path: Optional[str] = None, ) -> Dataset: """ This function initializes a dataset using a dictionary or pandas DataFrame or HuggingFace Datasets object. Arguments: dataset_dict (Dict[str, List[str]]): The dataset dictionary. section (str): The section of the dataset to use whose embeddings will be used for semantic search (e.g., 'text', 'title', etc.) (default: 'text'). index_column_name (str): The name of the column containing the embeddings (default: 'embeddings') embedding_type (str): The type of embedding to use (default: 'last_hidden_state'). batch_size (int, optional): The batch size to use (default: 8). max_length (int, optional): The maximum length of the input sequences. num_workers (int, optional): The number of workers to use. save_path (Optional[str], optional): The path to save the dataset (default: None). Returns: Dataset: The dataset object (HuggingFace Datasets). Raises: ValueError: If the dataset is not a dictionary or pandas DataFrame or HuggingFace Datasets object. """ # Create the dataset if isinstance(corpus, dict): self.dataset = Dataset.from_dict(corpus) elif isinstance(corpus, pd.DataFrame): self.dataset = Dataset.from_pandas(corpus) elif isinstance(corpus, Dataset): self.dataset = corpus else: raise ValueError('The dataset must be a dictionary or pandas DataFrame.') # Set the embedding_type self.embedding_type = embedding_type # Tokenize the dataset # self.dataset = self.dataset.map( # lambda x: x[section], # batched=True, # batch_size=batch_size, # num_proc=num_workers, # ) # Map the section of the dataset to the embeddings self.dataset = self.dataset.map( lambda x: { index_column_name: self.get_embeddings(x[section], embedding_type=self.embedding_type).detach().cpu().numpy()[0] }, # batched=True, batch_size=batch_size, num_proc=num_workers, ) # Save the dataset if save_path is not None: self.dataset.to_json(save_path) # Add FAISS index self.add_faiss_index( column_name=index_column_name, ) # Return the dataset return self.dataset
# Initialize the dataset using a JSON file
[docs] def load_dataset_from_json(self, json_path: str, ) -> Dataset: """ This function loads a dataset from a JSON file. Arguments: json_path (str): The path to the JSON file. Returns: Dataset: The dataset. """ # Load the dataset self.dataset = Dataset.from_json(json_path) # Return the dataset return self.dataset
# Search for the most similar elements in the dataset, given a query
[docs] def search(self, query: str, k: int = 1, index_column_name: str = 'embeddings', ) -> pd.DataFrame: """ This function searches for the most similar elements in the dataset, given a query. Arguments: query (str): The query. k (int, optional): The number of elements to return (default: 1). index_column_name (str, optional): The name of the column containing the embeddings (default: 'embeddings') Returns: pd.DataFrame: The most similar elements in the dataset (text, score, etc.), sorted by score. Remarks: The returned elements are dictionaries containing the text and the score. """ # Get the embeddings of the query query_embeddings = self.get_embeddings([query], embedding_type=self.embedding_type).detach().cpu().numpy() # Search for the most similar elements in the dataset scores, similar_elts = self.dataset.get_nearest_examples( index_name=index_column_name, query=query_embeddings, k=k, ) # Convert the results to a pandas DataFrame results_df = pd.DataFrame.from_dict(similar_elts) # Add the scores results_df['score'] = scores # Sort the results by score results_df.sort_values("score", ascending=True, inplace=True) # Return the most similar elements return results_df