"""Plot per-chain summary statistics from an experiment CSV.

The input CSV must provide the columns ``experiment_name``, ``chain``,
``mean``, ``std``, ``min`` and ``max``.  For every distinct ``chain`` value one
PNG is written next to the input file, showing box plots (with the individual
points overlaid) of the four statistic columns, and a ``describe()`` summary
of those columns is printed.
"""
import argparse
from pathlib import Path
from typing import Optional, Sequence

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# CSV column -> label shown on the x axis of each plot.
STAT_LABELS = {'mean': 'Mean', 'std': 'Std', 'min': 'Min', 'max': 'Max'}
STAT_COLUMNS = list(STAT_LABELS)


def parse_arguments(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command line.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        Namespace whose ``input`` attribute is the path to the input CSV file.
    """
    parser = argparse.ArgumentParser(description='Analyze chain data from CSV file.')
    parser.add_argument('--input', '-i', required=True, help='Path to the input CSV file')
    return parser.parse_args(argv)


def _experiment_label(df: pd.DataFrame) -> str:
    """Return the first row's experiment name with any '-' suffix removed."""
    # The name is expected to be identical on every row, so the first row is
    # representative.  str() guards against pandas having parsed the column
    # as a numeric dtype, which would make the split below fail.
    raw_name = str(df['experiment_name'].iloc[0])
    # Strip timestamp from experiment_name if it exists (text after the first '-').
    return raw_name.split('-', 1)[0]


def _chain_slug(chain_name) -> str:
    """Normalize a chain name so it can be embedded in a file name."""
    return str(chain_name).replace('--> /', '-').replace('/', '_').replace(' ', '')


def _output_path(input_path: str, chain_slug: str) -> Path:
    """Build the PNG path for one chain, next to the input file.

    Only the final extension of the input file name is dropped, so the result
    can never coincide with the input file itself, whatever its extension
    (``str.replace('.csv', ...)`` left non-``.csv`` paths untouched and also
    rewrote ``.csv`` occurring inside directory names).
    """
    source = Path(input_path)
    return source.parent / f'{source.stem}_chain_{chain_slug}_analysis.png'


def _plot_chain(chain_name, chain_data: pd.DataFrame, experiment_name: str,
                output_file: Path) -> None:
    """Draw the four statistic box plots for one chain and save them as PNG."""
    # One plot column per statistic, labelled with the capitalized names.
    plot_data = chain_data[STAT_COLUMNS].rename(columns=STAT_LABELS)

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        # Create boxplots
        sns.boxplot(data=plot_data, palette='Set3', ax=ax)
        # Add individual data points
        sns.stripplot(data=plot_data, color='black', alpha=0.5, size=4, jitter=True, ax=ax)

        # Set labels and title
        ax.set_title(
            f'Statistics for Chain: {chain_name}\nAcross {len(chain_data)} Experiment Runs\n{experiment_name}',
            fontsize=14,
        )
        ax.set_ylabel('Latency (ms)', fontsize=12)
        ax.set_xlabel('Statistic Type', fontsize=12)

        # Add grid for better readability
        ax.grid(axis='y', linestyle='--', alpha=0.7)

        # Tighten layout and save the figure
        fig.tight_layout()
        fig.savefig(output_file, dpi=300)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Read the input CSV and emit one plot and one printed summary per chain.

    Args:
        argv: Optional argument list forwarded to :func:`parse_arguments`.

    Raises:
        ValueError: If a required column is missing or the CSV has no rows.
    """
    args = parse_arguments(argv)

    # Load the CSV file from the input argument
    df = pd.read_csv(args.input)

    # Validate everything up front so a malformed file fails before any plot
    # has been written.
    if 'experiment_name' not in df.columns:
        raise ValueError("Input CSV must contain 'experiment_name' column.")
    missing = [col for col in ['chain'] + STAT_COLUMNS if col not in df.columns]
    if missing:
        raise ValueError(f"Input CSV is missing required column(s): {', '.join(missing)}")
    if df.empty:
        raise ValueError("Input CSV contains no data rows.")

    experiment_name = _experiment_label(df)

    # For each chain, create a plot with four boxplots (mean, std, min, max)
    for chain_name, chain_data in df.groupby('chain'):
        output_file = _output_path(args.input, _chain_slug(chain_name))
        _plot_chain(chain_name, chain_data, experiment_name, output_file)

        # Also calculate and print summary statistics for this chain
        summary = chain_data.describe()
        print(f"\nSummary for chain: {chain_name}")
        print(summary[STAT_COLUMNS])

    print(f"\nAnalysis complete. Plots saved with base name: {_output_path(args.input, '*')}")


if __name__ == "__main__":
    main()