76 lines
No EOL
2.7 KiB
Python
76 lines
No EOL
2.7 KiB
Python
import argparse
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


def parse_arguments(argv=None):
    """Parse the command-line options for the chain-analysis script.

    Args:
        argv: Argument strings to parse, without the program name. ``None``
            (the default) lets argparse read ``sys.argv[1:]``, which matches
            the previous behaviour. Passing an explicit list makes the parser
            usable from tests or other scripts.

    Returns:
        argparse.Namespace whose only attribute, ``input``, holds the CSV
        path given via ``--input`` / ``-i``.

    Raises:
        SystemExit: raised by argparse when ``--input`` is missing or the
            arguments are otherwise invalid.
    """
    parser = argparse.ArgumentParser(description='Analyze chain data from CSV file.')
    parser.add_argument('--input', '-i', required=True, help='Path to the input CSV file')
    return parser.parse_args(argv)


# Per-run statistic columns to plot, mapped to the tick label shown for each.
_STAT_COLUMNS = {'mean': 'Mean', 'std': 'Std', 'min': 'Min', 'max': 'Max'}


def _base_experiment_name(raw_name):
    """Return *raw_name* as text, cut off at its first '-'.

    Everything from the first hyphen onward is treated as a timestamp suffix
    and dropped. The value is passed through ``str`` first, so a non-string
    cell (e.g. NaN from an empty field, or a number) cannot make the hyphen
    test raise ``TypeError``.
    """
    return str(raw_name).partition('-')[0]


def _chain_filename_fragment(chain_name):
    """Return *chain_name* rewritten for embedding in a file name.

    '--> /' becomes '-', any remaining '/' becomes '_', and spaces are removed.
    """
    return str(chain_name).replace('--> /', '-').replace('/', '_').replace(' ', '')


def _chain_output_path(input_path, fragment):
    """Return the PNG path for one chain, in the same directory as *input_path*.

    The name is derived from the input file's stem rather than by replacing
    the text '.csv'. An input path without '.csv' therefore can no longer
    make the PNG overwrite the input file itself.
    """
    input_path = Path(input_path)
    return input_path.with_name(f'{input_path.stem}_chain_{fragment}_analysis.png')


def _plot_chain(chain_name, chain_data, experiment_name, output_file):
    """Save box plots of one chain's mean/std/min/max columns to *output_file*.

    Each run's value is overlaid as a jittered black point on its box.
    """
    # Select the statistic columns and relabel them for the x-axis ticks.
    plot_data = chain_data[list(_STAT_COLUMNS)].rename(columns=_STAT_COLUMNS)

    fig = plt.figure(figsize=(12, 8))
    try:
        sns.boxplot(data=plot_data, palette='Set3')
        sns.stripplot(data=plot_data, color='black', alpha=0.5, size=4, jitter=True)

        plt.title(f'Statistics for Chain: {chain_name}\nAcross {len(chain_data)} Experiment Runs\n{experiment_name}', fontsize=14)
        plt.ylabel('Latency (ms)', fontsize=12)
        plt.xlabel('Statistic Type', fontsize=12)
        plt.grid(axis='y', linestyle='--', alpha=0.7)

        plt.tight_layout()
        plt.savefig(output_file, dpi=300)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


def main():
    """Plot and summarise per-chain latency statistics from a CSV file.

    Reads the CSV named by ``--input``. For every distinct value in its
    ``chain`` column it:

    - saves a PNG of box plots of the ``mean``/``std``/``min``/``max``
      columns next to the input file, and
    - prints pandas ``describe()`` output for those columns.

    Raises:
        ValueError: if the CSV lacks a required column or has no data rows.
    """
    args = parse_arguments()

    df = pd.read_csv(args.input)

    if 'experiment_name' not in df.columns:
        raise ValueError("Input CSV must contain 'experiment_name' column.")
    missing = [col for col in ('chain', *_STAT_COLUMNS) if col not in df.columns]
    if missing:
        raise ValueError(f"Input CSV is missing required column(s): {', '.join(missing)}")
    if df.empty:
        # A header-only CSV would otherwise fail at .iloc[0] with IndexError.
        raise ValueError("Input CSV contains no data rows.")

    # Only the first row is read. experiment_name is expected to be the same
    # on every row.
    experiment_name = _base_experiment_name(df['experiment_name'].iloc[0])

    for chain_name, chain_data in df.groupby('chain'):
        output_file = _chain_output_path(args.input, _chain_filename_fragment(chain_name))
        _plot_chain(chain_name, chain_data, experiment_name, output_file)

        # describe() the statistic columns only, in mean/std/min/max order.
        summary = chain_data[list(_STAT_COLUMNS)].describe()
        print(f"\nSummary for chain: {chain_name}")
        print(summary)

    print(f"\nAnalysis complete. Plots saved with base name: {_chain_output_path(args.input, '*')}")


if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()