"""Plot per-chain distributions of latency statistics read from a CSV file.

For every distinct value in the ``chain`` column, one PNG figure with five
box plots (mean, std, min, max, count) is written next to the input file and
a ``describe()`` summary is printed to stdout.
"""
import argparse
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# (CSV column, label shown in the figure) for each plotted statistic.
_METRICS = (
    ('mean', 'Mean'),
    ('std', 'Std'),
    ('min', 'Min'),
    ('max', 'Max'),
    ('count', 'Count'),
)
_METRIC_COLUMNS = [column for column, _ in _METRICS]

# Quantile band kept for plotting; values outside it are hidden from the plot.
_LOWER_QUANTILE = 0.03
_UPPER_QUANTILE = 0.97


def parse_arguments():
    """Parse the command line and return the resulting namespace.

    Returns:
        argparse.Namespace with a single attribute ``input`` (path of the
        CSV file to analyze).
    """
    parser = argparse.ArgumentParser(description='Analyze chain data from CSV file.')
    parser.add_argument('--input', '-i', required=True, help='Path to the input CSV file')
    return parser.parse_args()


def _strip_timestamp(experiment_name):
    """Return the part of *experiment_name* before the first ``-``.

    NOTE(review): the text after the dash is presumably a timestamp suffix;
    confirm against the tool that writes the CSV.
    """
    # str() guards against pandas having parsed the column as a number.
    return str(experiment_name).partition('-')[0]


def _sanitize_chain_name(chain_name):
    """Return *chain_name* with arrows, slashes and spaces made filename-safe."""
    return str(chain_name).replace('--> /', '-').replace('/', '_').replace(' ', '')


def _output_path(input_path, chain_token):
    """Build the PNG path for one chain, next to the input file.

    Only a trailing ``.csv`` extension (any letter case) is removed. Unlike a
    plain ``str.replace('.csv', ...)``, a ``.csv`` occurring elsewhere in the
    path is left alone, and an input without that extension still yields a
    distinct file per chain instead of the input path itself.
    """
    root, ext = os.path.splitext(input_path)
    base = root if ext.lower() == '.csv' else input_path
    return f'{base}_chain_{chain_token}_analysis.png'


def _trim_outliers(values):
    """Return the entries of *values* inside the configured quantile band."""
    lower = values.quantile(_LOWER_QUANTILE)
    upper = values.quantile(_UPPER_QUANTILE)
    return values[values.between(lower, upper)]


def _format_stats_text(values, filtered_count):
    """Return the multi-line statistics box text for one metric.

    Labels are left-aligned and values right-aligned so that every line has
    the same width; the widest line has exactly one space after its colon.
    """
    rows = [
        ('Mean', f'{values.mean():.2f}'),
        ('Std', f'{values.std():.2f}'),
        ('Min', f'{values.min():.2f}'),
        ('Max', f'{values.max():.2f}'),
        ('Filtered', f'{filtered_count}'),
    ]
    # Width of the widest "Label: value" line.
    width = max(len(label) + 2 + len(value) for label, value in rows)
    return '\n'.join(
        f'{label}:{value.rjust(width - len(label) - 1)}' for label, value in rows
    )


def _plot_metric(ax, values, label, color):
    """Draw the box plot, data points and statistics box for one metric.

    Args:
        ax: Matplotlib axes to draw on.
        values: Series holding this metric for every row of the chain.
        label: Display name of the metric (also selects the y-axis label).
        color: Fill color of the box.
    """
    present = values.dropna()

    # Remove outliers for better visualization.
    trimmed = _trim_outliers(present)
    filtered_count = present.count() - trimmed.count()

    # Create boxplots
    sns.boxplot(data=trimmed, ax=ax, color=color, showfliers=False, width=0.4)  # type: ignore
    # Add individual data points
    sns.swarmplot(data=trimmed, ax=ax, color='black', size=3, alpha=0.6)  # type: ignore

    # Set labels and title
    ax.set_title(f'{label} Distribution', fontsize=14, fontweight='bold')
    ax.set_xticks([])  # Remove x-ticks for clarity
    ax.set_xlabel('')  # No x-label needed
    ax.set_ylabel('Latency (ms)' if label != 'Count' else 'Count', fontsize=12)

    # The statistics box is computed from the unfiltered values on purpose,
    # so it still reflects the outliers hidden from the plot.
    stats_text = _format_stats_text(values, filtered_count)

    # Place the box in the top right using axes fraction coordinates.
    ax.text(
        0.95, 0.98,  # axes fraction: 95% right, 98% up
        stats_text,
        transform=ax.transAxes,
        verticalalignment='top',
        horizontalalignment='right',
        fontsize=10,
        fontfamily='monospace',
        bbox=dict(facecolor='white', alpha=0.9, boxstyle='round,pad=0.3', edgecolor='gray'),
    )

    # Add grid for better readability
    ax.grid(axis='y', linestyle='--', alpha=0.4)


def _plot_chain(chain_name, chain_data, experiment_name, colors, input_path):
    """Render and save the five-panel figure for a single chain."""
    fig, axs = plt.subplots(1, len(_METRICS), figsize=(18, 6), constrained_layout=True)

    for ax, (column, label), color in zip(axs, _METRICS, colors):
        _plot_metric(ax, chain_data[column], label, color)

    fig.suptitle(
        f'Statistics for Chain: {chain_name}\n'
        f'Across {len(chain_data)} Experiment Runs - {experiment_name}',
        fontsize=18,
        fontweight='bold',
    )

    # Save the figure with a filename that includes the chain name.
    fig.savefig(_output_path(input_path, _sanitize_chain_name(chain_name)), dpi=300)
    # Close this figure explicitly so figures do not accumulate across chains.
    plt.close(fig)


def main():
    """Read the input CSV, write one figure per chain and print summaries.

    Raises:
        ValueError: If a required column is missing or the CSV has no rows.
    """
    args = parse_arguments()

    # Load the CSV file from the input argument
    df = pd.read_csv(args.input)

    # The experiment name is read from the first row only.
    if 'experiment_name' not in df.columns:
        raise ValueError("Input CSV must contain 'experiment_name' column.")
    missing = [column for column in ['chain'] + _METRIC_COLUMNS if column not in df.columns]
    if missing:
        raise ValueError(f"Input CSV is missing required column(s): {', '.join(missing)}")
    if df.empty:
        raise ValueError('Input CSV contains no data rows.')

    experiment_name = _strip_timestamp(df['experiment_name'].iloc[0])

    # Four related hues for the latency metrics plus a distinct color for
    # 'Count', which is a different kind of metric.
    colors = sns.color_palette('husl', 4) + ['lightcoral']

    for chain_name, chain_data in df.groupby('chain'):
        _plot_chain(chain_name, chain_data, experiment_name, colors, args.input)

        # Print summary statistics for the chain
        summary = chain_data.describe()
        print(f"\nSummary for chain: {chain_name}")
        print(summary[_METRIC_COLUMNS])

    print(f"\nAnalysis complete. Plots saved with base name: {_output_path(args.input, '*')}")


if __name__ == "__main__":
    main()