# dataflow-analysis/batch_analysis_analysis.py
# (exported from a web code viewer: 76 lines, 2.7 KiB, Python; timestamp 2025-06-10 14:36:56 +00:00)
import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
def parse_arguments(argv=None):
    """Parse the command-line options of the chain analysis script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behaviour of the original zero-argument call.

    Returns:
        An ``argparse.Namespace`` whose ``input`` attribute holds the path
        given via ``--input`` / ``-i``.

    Raises:
        SystemExit: Raised by argparse when the required ``--input`` option
            is missing or the arguments are otherwise invalid.
    """
    parser = argparse.ArgumentParser(description='Analyze chain data from CSV file.')
    parser.add_argument('--input', '-i', required=True, help='Path to the input CSV file')
    return parser.parse_args(argv)
# Columns the input CSV must provide for the analysis to run.
_REQUIRED_COLUMNS = ('experiment_name', 'chain', 'mean', 'std', 'min', 'max')

# Statistic columns plotted for every chain, mapped to their x-axis labels.
_STAT_LABELS = {'mean': 'Mean', 'std': 'Std', 'min': 'Min', 'max': 'Max'}


def _output_base(input_path):
    """Return *input_path* without its file extension.

    Using ``splitext`` instead of ``str.replace('.csv', ...)`` touches only
    the real extension, so a ``.csv`` elsewhere in the path is left alone and
    an input with a different (or no) extension never maps back onto itself.
    """
    return os.path.splitext(input_path)[0]


def _chain_filename_part(chain_name):
    """Turn a chain name into a fragment usable inside a filename."""
    return str(chain_name).replace('--> /', '-').replace('/', '_').replace(' ', '')


def _plot_chain(chain_name, chain_data, experiment_name, output_file):
    """Draw box and strip plots of one chain's statistics and save the image."""
    plot_data = chain_data[list(_STAT_LABELS)].rename(columns=_STAT_LABELS)

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        # Boxplots show the distribution; the strip plot overlays each row.
        sns.boxplot(data=plot_data, palette='Set3', ax=ax)
        sns.stripplot(data=plot_data, color='black', alpha=0.5, size=4, jitter=True, ax=ax)

        ax.set_title(
            f'Statistics for Chain: {chain_name}\nAcross {len(chain_data)} Experiment Runs\n{experiment_name}',
            fontsize=14,
        )
        ax.set_ylabel('Latency (ms)', fontsize=12)
        ax.set_xlabel('Statistic Type', fontsize=12)
        ax.grid(axis='y', linestyle='--', alpha=0.7)

        fig.tight_layout()
        fig.savefig(output_file, dpi=300)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)


def main():
    """Plot and summarise per-chain statistics from the input CSV.

    Reads the CSV named by the ``--input`` option, groups its rows by the
    ``chain`` column and, for every chain, saves a PNG next to the input
    file (``<input-without-extension>_chain_<chain>_analysis.png``) showing
    boxplots of the ``mean``, ``std``, ``min`` and ``max`` columns. A
    ``describe()`` summary of those columns is printed for each chain.

    Raises:
        ValueError: If a required column is missing or the CSV has no rows.
    """
    args = parse_arguments()
    df = pd.read_csv(args.input)

    # Validate up front so a malformed file fails with a clear message
    # instead of a KeyError/IndexError from deep inside the loop.
    if 'experiment_name' not in df.columns:
        raise ValueError("Input CSV must contain 'experiment_name' column.")
    missing = [column for column in _REQUIRED_COLUMNS if column not in df.columns]
    if missing:
        raise ValueError(f"Input CSV is missing required column(s): {', '.join(missing)}")
    if df.empty:
        raise ValueError("Input CSV contains no data rows.")

    # The first row's experiment name is used for every plot title. Only the
    # text before the first '-' is kept (intended to drop a timestamp suffix);
    # str() guards against pandas having parsed the column as a number.
    experiment_name = str(df['experiment_name'].iloc[0]).split('-')[0]

    output_base = _output_base(args.input)
    stat_columns = list(_STAT_LABELS)

    for chain_name, chain_data in df.groupby('chain'):
        output_file = f'{output_base}_chain_{_chain_filename_part(chain_name)}_analysis.png'
        _plot_chain(chain_name, chain_data, experiment_name, output_file)

        # Also print summary statistics for this chain.
        print(f"\nSummary for chain: {chain_name}")
        print(chain_data[stat_columns].describe())

    print(f"\nAnalysis complete. Plots saved with base name: {output_base}_chain_*_analysis.png")
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()