# dataflow-analysis/batch_analysis_analysis.py
#
# File-viewer metadata: 118 lines, no EOL, 5.3 KiB, Python

import argparse
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
def parse_arguments(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the chain analysis script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process arguments. When ``None`` (the default) argparse reads
            ``sys.argv[1:]``, which is the behaviour existing callers rely on.

    Returns:
        The parsed namespace; ``input`` holds the path to the input CSV file.
    """
    parser = argparse.ArgumentParser(description='Analyze chain data from CSV file.')
    parser.add_argument('--input', '-i', required=True, help='Path to the input CSV file')
    return parser.parse_args(argv)
# Per-run statistic columns every input row must provide, in subplot order.
_STAT_COLUMNS = ('mean', 'std', 'min', 'max', 'count')


def _format_stats_box(values, filtered_count):
    """Build the multi-line summary text shown inside one subplot.

    Each line is ``Label:`` followed by its value right-aligned so that all
    lines share the width of the longest ``Label: value`` pair.

    Args:
        values: Series holding one statistic column (outliers included).
        filtered_count: Number of points hidden by the outlier filter.

    Returns:
        The newline-joined summary text.
    """
    rows = [
        ('Mean', f"{values.mean():.2f}"),
        ('Std', f"{values.std():.2f}"),
        ('Min', f"{values.min():.2f}"),
        ('Max', f"{values.max():.2f}"),
        ('Filtered', f"{filtered_count}"),
    ]
    # Width of the longest "Label: value" rendering; every line is padded to it.
    width = max(len(label) + 2 + len(text) for label, text in rows)
    return "\n".join(
        f"{label}:{text:>{width - len(label) - 1}}" for label, text in rows
    )


def _plot_chain(chain_name, chain_data, experiment_name, output_file, colors):
    """Draw one figure with a boxplot per statistic column and save it.

    Args:
        chain_name: Group key of the chain, used in the figure title.
        chain_data: Rows of the input frame belonging to this chain.
        experiment_name: Experiment label appended to the figure title.
        output_file: Destination path of the PNG file.
        colors: One box color per entry of ``_STAT_COLUMNS``.
    """
    fig, axs = plt.subplots(1, len(_STAT_COLUMNS), figsize=(18, 6), constrained_layout=True)

    for ax, column, color in zip(axs, _STAT_COLUMNS, colors):
        label = column.capitalize()
        values = chain_data[column].rename(label)
        present = values.dropna()

        # Hide everything outside the 3%-97% quantile band so a few extreme
        # runs do not squash the box; only the drawing is filtered.
        lower, upper = present.quantile(0.03), present.quantile(0.97)
        trimmed = present[present.between(lower, upper)]
        filtered_count = present.count() - trimmed.count()

        sns.boxplot(data=trimmed, ax=ax, color=color, showfliers=False, width=0.4)  # type: ignore
        # Overlay the individual data points on top of the box.
        sns.swarmplot(data=trimmed, ax=ax, color='black', size=3, alpha=0.6)  # type: ignore

        ax.set_title(f'{label} Distribution', fontsize=14, fontweight='bold')
        ax.set_xticks([])  # a single box needs no x-ticks
        ax.set_xlabel('')
        ax.set_ylabel('Latency (ms)' if label != 'Count' else 'Count', fontsize=12)

        # The summary box is computed from the unfiltered values on purpose,
        # so it reports the outliers that the drawing hides.
        ax.text(
            0.95, 0.98,  # axes fraction: 95% right, 98% up
            _format_stats_box(values, filtered_count),
            transform=ax.transAxes,
            verticalalignment='top',
            horizontalalignment='right',
            fontsize=10,
            fontfamily='monospace',
            bbox=dict(facecolor='white', alpha=0.9, boxstyle='round,pad=0.3', edgecolor='gray'),
        )
        ax.grid(axis='y', linestyle='--', alpha=0.4)

    fig.suptitle(
        f'Statistics for Chain: {chain_name}\nAcross {len(chain_data)} Experiment Runs - {experiment_name}',
        fontsize=18, fontweight='bold'
    )
    fig.savefig(output_file, dpi=300)
    # Close this specific figure so memory does not grow with the chain count.
    plt.close(fig)


def main():
    """Read the input CSV, plot per-chain statistic distributions, print summaries.

    For every distinct value of the ``chain`` column one PNG is written next
    to the input file as ``<input stem>_chain_<chain>_analysis.png``, and the
    ``describe()`` summary of the statistic columns is printed.

    Raises:
        ValueError: If the CSV lacks a required column or has no data rows.
    """
    args = parse_arguments()
    input_path = Path(args.input)
    df = pd.read_csv(input_path)

    if 'experiment_name' not in df.columns:
        raise ValueError("Input CSV must contain 'experiment_name' column.")
    missing = [name for name in ('chain', *_STAT_COLUMNS) if name not in df.columns]
    if missing:
        raise ValueError(f"Input CSV is missing required column(s): {', '.join(missing)}")
    if df.empty:
        raise ValueError("Input CSV contains no data rows.")

    # The experiment name is taken from the first row; everything from the
    # first '-' onwards (the timestamp suffix) is dropped. str() guards
    # against pandas having parsed the column as a non-string type.
    experiment_name = str(df['experiment_name'].iloc[0]).partition('-')[0]

    # Four matching colors for the latency statistics plus a distinct one for
    # 'count', which is a different kind of metric. Built once for all chains.
    colors = list(sns.color_palette("husl", 4)) + ['lightcoral']

    for chain_name, chain_data in df.groupby('chain'):
        # Normalize the chain name so it is usable inside a filename.
        chain_name_fs = str(chain_name).replace('--> /', '-').replace('/', '_').replace(' ', '')
        # Derive the output from the path's stem instead of str.replace('.csv'),
        # which left the path unchanged for inputs without a '.csv' suffix and
        # rewrote '.csv' occurring anywhere else in the path.
        output_file = input_path.parent / f"{input_path.stem}_chain_{chain_name_fs}_analysis.png"
        _plot_chain(chain_name, chain_data, experiment_name, output_file, colors)

        summary = chain_data.describe()
        print(f"\nSummary for chain: {chain_name}")
        print(summary[list(_STAT_COLUMNS)])

    output_pattern = input_path.parent / f"{input_path.stem}_chain_*_analysis.png"
    print(f"\nAnalysis complete. Plots saved with base name: {output_pattern}")
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()