# Source code for string2string.misc.plotting_functions

"""
This module contains the functions for plotting and visualizing the results.
"""

# Matplotlib
import matplotlib
import matplotlib.pyplot as plt

# Plotly
import plotly.graph_objects as go
import plotly.express as px

# Other necessary packages
import numpy as np
import torch
from typing import List, Union, Tuple, Optional
Coordinate = Union[int, float]

# Plot the pairwise alignment between two strings (or lists of strings)
def plot_pairwise_alignment(
    seq1_pieces: Union[str, List[Union[str, int, float]], np.ndarray],
    seq2_pieces: Union[str, List[Union[str, int, float]], np.ndarray],
    alignment: Optional[List[Tuple[int, int]]] = None,
    str2colordict: Optional[dict] = None,
    padding_factor: float = 1.4,
    linewidth: float = 1.5,
    border_to_box: float = 0.2,
    title: str = 'Pairwise Alignment',
    seq1_name: str = 'Seq 1',
    seq2_name: str = 'Seq 2',
    show: bool = True,
    save: bool = False,
    save_path: str = 'pairwise_alignment.png',
    save_dpi: int = 300,
    save_bbox_inches: str = 'tight',
) -> None:
    """
    This function generates a plot that displays the alignment between two sequences of pieces
    (characters, strings, integers, or floats). The pieces of the first sequence are drawn as a row
    of boxes at y = 0, the pieces of the second sequence as a row of boxes at y = 1, and every
    aligned pair (i, j) is connected by a line between the two rows.

    Arguments:
        seq1_pieces (Union[str, List[Union[str, int, float]], np.ndarray]): The pieces of the first sequence.
        seq2_pieces (Union[str, List[Union[str, int, float]], np.ndarray]): The pieces of the second sequence.
        alignment (List[Tuple[int, int]], optional): The pairwise alignment between the two sequences
            (default is None, which is treated as an empty alignment, i.e., no lines are drawn).
        str2colordict (dict, optional): A dictionary that maps a piece (as a whitespace-stripped string)
            to the face color of its box. Pieces that are not in the dictionary use a default light blue.
        padding_factor (float, optional): The padding of the box around each piece (default is 1.4).
        linewidth (float, optional): The line width of the border of the box around each piece (default is 1.5).
            The lines connecting aligned pieces use a fixed width of 0.75.
        border_to_box (float, optional): The vertical gap between a row and the end points of the
            alignment lines (default is 0.2).
        title (str, optional): The title of the plot (default is 'Pairwise Alignment').
        seq1_name (str, optional): The name of the first sequence (default is 'Seq 1').
        seq2_name (str, optional): The name of the second sequence (default is 'Seq 2').
        show (bool, optional): Whether to show the plot (default is True).
        save (bool, optional): Whether to save the plot (default is False).
        save_path (str, optional): The path to save the plot (default is 'pairwise_alignment.png').
        save_dpi (int, optional): The dpi to use for the saved plot (default is 300).
        save_bbox_inches (str, optional): The bbox_inches to use for the saved plot (default is 'tight').

    Returns:
        None

    Raises:
        TypeError: If seq1_pieces and seq2_pieces are not of the same type.
        ValueError: If save is True and save_path is None, if either sequence is empty, or if the
            string form of any piece is longer than 20 characters.

    .. note::
        The pairwise alignment is a list of tuples of the form (i, j) where i is the index of the piece
        in the first sequence and j is the index of the piece in the second sequence.
    """
    # Both sequences must be given in the same container type
    if type(seq1_pieces) is not type(seq2_pieces):
        raise TypeError('seq1_pieces and seq2_pieces must be of the same type.')

    # A save path is required when saving is requested
    if save and save_path is None:
        raise ValueError('Save path is not specified.')

    # Work with the string form of every piece so that integers and floats can be
    # measured, stripped, looked up in the color dictionary, and drawn like strings
    labels1 = [str(piece) for piece in seq1_pieces]
    labels2 = [str(piece) for piece in seq2_pieces]
    len1, len2 = len(labels1), len(labels2)
    if len1 == 0 or len2 == 0:
        raise ValueError('seq1_pieces and seq2_pieces must not be empty.')
    max_len = max(len1, len2)

    # The longest piece determines how wide the figure needs to be
    max_len_chr = max(len(label) for label in labels1 + labels2)
    if max_len_chr > 20:
        raise ValueError('The maximum length of the characters in the strings must be less than 20.')
    factor = 0.5 + (max_len_chr // 10.0) * 0.5

    # Coordinates of the boxes: the first sequence sits at y = 0, the second at y = 1
    x_char = np.concatenate((np.arange(len1), np.arange(len2)))
    y_char = np.concatenate((np.zeros(len1), np.ones(len2)))

    # A missing alignment is treated as an empty one
    pairs = np.array([] if alignment is None else alignment)

    # Create the figure
    fig, ax = plt.subplots(figsize=(2 * max_len * factor, 2 * 2))

    # Draw one line per aligned pair (drawn below the boxes via zorder)
    if len(pairs) > 0:
        for index1, index2 in zip(pairs[:, 0], pairs[:, 1]):
            ax.plot(
                [index1, index2],
                [border_to_box, 1. - border_to_box],
                'o-',
                color='#336699',
                linewidth=0.75,
                zorder=2,
            )

    # Draw the pieces, each inside a colored box
    default_color = (0.88, 0.94, 1.0, 1.0)
    for x, y, label in zip(x_char, y_char, labels1 + labels2):
        key = label.strip()
        if str2colordict is not None and key in str2colordict:
            fc_color = str2colordict[key]
        else:
            fc_color = default_color
        ax.text(
            x,
            y,
            label,
            size=12,
            ha='center',
            va='center',
            bbox=dict(
                facecolor=fc_color,
                edgecolor='#336699',
                linewidth=linewidth,
                boxstyle=f'square,pad={padding_factor}',
                alpha=0.99,
            ),
        )

    # Set the limits of the axes
    ax.set_xlim(-0.5, max_len - 0.5)
    ax.set_ylim(-0.5, 1.5)

    # Label the two rows with the names of the sequences
    ax.set_yticks([0, 1])
    ax.set_yticklabels([seq1_name, seq2_name], fontsize=12)

    # Set the title of the axes
    ax.set_title(title, fontsize=14, fontweight='book')

    # Turn off the spines and the x-axis ticks
    ax.spines[:].set_visible(False)
    ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)

    # Tight layout
    plt.tight_layout()

    # Save before showing: once plt.show() returns, the figure may already have been
    # closed, in which case saving afterwards would write an empty image
    if save:
        fig.savefig(save_path, dpi=save_dpi, bbox_inches=save_bbox_inches)

    # Show the plot
    if show:
        plt.show()
# Plot a heatmap
def plot_heatmap(
    data: Union[List[List[Union[str, int, float]]], np.ndarray],
    title: str = 'Heatmap',
    x_label: str = 'X',
    y_label: str = 'Y',
    x_ticks: Optional[List[str]] = None,
    y_ticks: Optional[List[str]] = None,
    colorbar_kwargs: Optional[dict] = None,
    color_threshold: Optional[float] = None,
    textcolors=("black", "white"),
    valfmt="{x:.1f}",
    legend: bool = False,
    show: bool = True,
    save: bool = False,
    save_path: str = 'heatmap.png',
    save_dpi: int = 300,
    save_bbox_inches: str = 'tight',
    **kwargs) -> None:
    """
    This function creates a heatmap visualization of a 2D array of data (e.g., a confusion matrix or
    a correlation matrix), given either as a list of lists or as a numpy array. Every cell is colored
    according to its value and annotated with the formatted value.

    Arguments:
        data (Union[List[List[Union[str, int, float]]], np.ndarray]): The 2D data to plot.
        title (str, optional): The title of the plot (default: 'Heatmap').
        x_label (str, optional): The label of the x-axis (default: 'X').
        y_label (str, optional): The label of the y-axis (default: 'Y').
        x_ticks (List[str], optional): The tick labels of the x-axis (default: None).
        y_ticks (List[str], optional): The tick labels of the y-axis (default: None).
        colorbar_kwargs (dict, optional): The keyword arguments for the colorbar (default: None).
        color_threshold (float, optional): The data value above which a cell is annotated with
            textcolors[1] instead of textcolors[0] (default: None, which uses half of the normalized
            maximum of the data).
        textcolors (tuple, optional): The colors to use for the text (default: ("black", "white")).
        valfmt (str, optional): The format to use for the values; either a format string or a formatter
            that is called as valfmt(value, None) (default: "{x:.1f}").
        legend (bool, optional): Whether to show the colorbar (default: False).
        show (bool, optional): Whether to show the plot (default: True).
        save (bool, optional): Whether to save the plot (default: False).
        save_path (str, optional): The path to save the plot (default: 'heatmap.png').
        save_dpi (int, optional): The dpi to use for the saved plot (default: 300).
        save_bbox_inches (str, optional): The bbox_inches to use for the saved plot (default: 'tight').
        **kwargs: The keyword arguments passed on to imshow.

    Returns:
        None

    Raises:
        ValueError: If data is not two-dimensional, or if save is True and save_path is None.
    """
    # Lists of lists are accepted as well, but the code below relies on the numpy array interface
    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError('data must be a 2D array or a list of lists.')

    # A save path is required when saving is requested
    if save and save_path is None:
        raise ValueError('Save path is not specified.')

    # Create the figure and axes
    fig, ax = plt.subplots()

    # Create the heatmap
    im = ax.imshow(data, **kwargs)

    # Create the colorbar (only when a legend is requested)
    if legend:
        ax.figure.colorbar(im, ax=ax, **(colorbar_kwargs if colorbar_kwargs is not None else {}))

    # Set the axis labels
    ax.set_xlabel(x_label, fontweight='medium')
    ax.set_ylabel(y_label, fontweight='medium')

    # Set the x-axis ticks
    if x_ticks is not None:
        ax.set_xticks(np.arange(len(x_ticks)))
        ax.set_xticklabels(x_ticks)

    # Set the y-axis ticks
    if y_ticks is not None:
        ax.set_yticks(np.arange(len(y_ticks)))
        ax.set_yticklabels(y_ticks)

    # Put the x-axis ticks and tick labels on top of the heatmap
    ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)

    # Center the x-axis tick labels on their ticks
    plt.setp(ax.get_xticklabels(), ha="center", rotation_mode="anchor")

    # Turn off the spines
    ax.spines[:].set_visible(False)

    # Place minor ticks on the cell borders and draw a white grid along them to separate the cells
    n_rows, n_cols = data.shape
    ax.set_xticks(np.arange(n_cols + 1) - .5, minor=True)
    ax.set_yticks(np.arange(n_rows + 1) - .5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)
    ax.xaxis.set_label_position('top')

    # Normalized threshold that decides which of the two text colors a cell gets
    if color_threshold is not None:
        threshold = im.norm(color_threshold)
    else:
        threshold = im.norm(data.max()) / 2.

    # Value format for the text annotations
    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    # Annotate every cell with its formatted value
    for i in range(n_rows):
        for j in range(n_cols):
            value = data[i, j]
            ax.text(
                j,
                i,
                valfmt(value, None),
                horizontalalignment="center",
                verticalalignment="center",
                color=textcolors[int(im.norm(value) > threshold)],
            )

    # Set the title
    fig.suptitle(title, fontsize=14, fontweight='semibold')

    # Set the tight layout
    plt.tight_layout()

    # Save before showing: once plt.show() returns, the figure may already have been
    # closed, in which case saving afterwards would write an empty image
    if save:
        fig.savefig(save_path, dpi=save_dpi, bbox_inches=save_bbox_inches)

    # Show the plot
    if show:
        plt.show()
def _embeddings_to_numpy(embeddings):
    """Convert a torch tensor or a list of embeddings to a numpy array; return any other input unchanged."""
    if isinstance(embeddings, torch.Tensor):
        return embeddings.detach().cpu().numpy()
    if isinstance(embeddings, list):
        return np.array(embeddings)
    return embeddings


def plot_corpus_embeds_with_plotly(
    corpus_embeddings: Union[List[List[Coordinate]], np.ndarray, torch.Tensor],
    corpus_labels: List[str],
    corpus_hover_texts: List[str],
    corpus_scatter_kwargs: Optional[dict] = None,
    layoot_dict: Optional[dict] = None,
    query_embeddings: Optional[Union[List[List[Coordinate]], np.ndarray, torch.Tensor]] = None,
    query_labels: Optional[List[str]] = None,
    query_hover_texts: Optional[List[str]] = None,
    query_modes: Optional[Union[List[str], str]] = 'markers',
    query_marker_dict: Optional[dict] = None,
    show_plot: bool = True,
    save_path: Optional[str] = None,
) -> go.Figure:
    """
    The purpose of this function is to generate a 2D scatter plot using plotly, based on a given corpus
    of embeddings and their corresponding labels. The first two coordinates of every corpus embedding
    are used as the x and y position of its point, and the points are colored by their labels.
    Optionally, query embeddings can be added on top of the corpus; every query is added as a separate
    trace with its own label, hover text, and mode.

    Arguments:
        corpus_embeddings: A list of lists or a numpy array or a torch tensor of corpus embeddings (e.g. sentence embeddings).
        corpus_labels: A list of labels for the corpus embeddings.
        corpus_hover_texts: A list of hover texts for the corpus embeddings.
        corpus_scatter_kwargs: A dictionary of keyword arguments for the corpus scatter plot (e.g. marker size, marker color, etc.)
            (default: None, which passes no additional keyword arguments).
        layoot_dict: A dictionary of keyword arguments for the layout of the plot (e.g. title, x-axis title, y-axis title, etc.)
            (default: None). The spelling of the parameter name is kept as is so that existing keyword callers keep working.
        query_embeddings: A list of lists or a numpy array or a torch tensor of query embeddings (e.g. sentence embeddings) (default: None).
        query_labels: A list of labels for the query embeddings (default: None, which labels every query 'Query').
        query_hover_texts: A list of hover texts for the query embeddings (default: None, which uses 'Query' for every query).
        query_modes: A mode for all query embeddings or a list with one mode per query embedding (default: 'markers').
        query_marker_dict: A dictionary of keyword arguments for the query markers (e.g. marker size, marker color, etc.)
            (default: None, which uses black markers of size 10).
        show_plot: A boolean whether to show the plot (default: True).
        save_path: A string of the path to save the plot as an HTML file (e.g., 'corpus_embeddings.html') (default: None).

    Returns:
        go.Figure: A plotly figure object.

    .. note::
        Please refer to the Hands-on Tutorial on Semantic Search with HUPD Patent Data for a good demonstration of how to use this function.
    """
    # If the corpus_embeddings are a torch tensor or a list, we convert them to a numpy array
    corpus_embeddings = _embeddings_to_numpy(corpus_embeddings)

    # The parameter is optional, so None has to behave like an empty dictionary
    scatter_kwargs = corpus_scatter_kwargs if corpus_scatter_kwargs is not None else {}

    # Let us plot the corpus embeddings
    fig = px.scatter(
        corpus_embeddings,
        x=0,
        y=1,
        color=corpus_labels,
        hover_name=corpus_hover_texts,
        **scatter_kwargs,
    )

    # If we have query embeddings, we plot them as well
    if query_embeddings is not None:
        # If the query_embeddings are a torch tensor or a list, we convert them to a numpy array
        query_embeddings = _embeddings_to_numpy(query_embeddings)

        # Check if markers are specified for the query embeddings
        q_marker_dict = query_marker_dict if query_marker_dict is not None else dict(size=10, color='black')

        # If the query_modes is a string, we use the same mode for all of the query embeddings
        if isinstance(query_modes, str):
            query_modes = [query_modes] * len(query_embeddings)

        # Let us plot the query embeddings on top of the corpus embeddings, one trace per query
        for i, query_embedding in enumerate(query_embeddings):
            q_mode = query_modes[i] if query_modes is not None else 'markers'
            q_label = query_labels[i] if query_labels is not None else 'Query'
            q_hover_text = query_hover_texts[i] if query_hover_texts is not None else 'Query'
            fig.add_trace(
                go.Scatter(
                    x=[query_embedding[0]],
                    y=[query_embedding[1]],
                    mode=q_mode,
                    marker=q_marker_dict,
                    name=q_label,
                    hovertext=q_hover_text,
                )
            )

    # If we have a layout dictionary, we update the figure layout with it
    if layoot_dict is not None:
        fig.update_layout(layoot_dict)

    # If we want to save the plot, we do it here
    if save_path is not None:
        fig.write_html(save_path)

    # If we want to show the plot, we do it here
    if show_plot:
        fig.show()

    return fig