118 lines
No EOL
5.3 KiB
Python
118 lines
No EOL
5.3 KiB
Python
import argparse
import os
import re

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
|
|
|
|
def parse_arguments():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace whose ``input`` attribute is the path passed via
        the required ``--input`` / ``-i`` option.
    """
    cli = argparse.ArgumentParser(
        description='Analyze chain data from CSV file.',
    )
    cli.add_argument(
        '--input', '-i',
        required=True,
        help='Path to the input CSV file',
    )
    namespace = cli.parse_args()
    return namespace
|
|
|
|
# CSV columns holding the per-run statistics, and the display labels used
# for them in the plots (same order).
_METRIC_COLUMNS = ['mean', 'std', 'min', 'max', 'count']
_METRIC_LABELS = ['Mean', 'Std', 'Min', 'Max', 'Count']

# Quantile bounds used to trim the tails before plotting (visualization only;
# the stats box is still computed on the untrimmed values).
_LOWER_QUANTILE = 0.03
_UPPER_QUANTILE = 0.97


def _strip_timestamp(experiment_name):
    """Return the part of *experiment_name* before the first '-'.

    Everything after the first hyphen is treated as a timestamp suffix and
    dropped. The value is converted with ``str`` first so a non-string cell
    (e.g. a number parsed by pandas) does not raise ``TypeError``.
    """
    return str(experiment_name).split('-', 1)[0]


def _sanitize_chain_name(chain_name):
    """Turn a chain label into a string that is safe to embed in a filename.

    Applies the replacements '--> /' -> '-', '/' -> '_' and removes spaces.
    It then replaces any remaining character other than word characters,
    '.' or '-' with '_'. This catches e.g. a leftover '>' or ':', which are
    invalid in Windows filenames.
    """
    name = str(chain_name).replace('--> /', '-').replace('/', '_').replace(' ', '')
    return re.sub(r'[^\w.\-]', '_', name)


def _output_path(input_path, chain_name_fs):
    """Build the PNG path for one chain, placed next to the input file.

    Only the input's extension is replaced, so inputs that do not end in
    '.csv' still get distinct per-chain paths. Other '.csv' substrings in the
    path are left alone.
    """
    root, _ext = os.path.splitext(input_path)
    return f'{root}_chain_{chain_name_fs}_analysis.png'


def _format_stats_text(values, filtered_count):
    """Build the monospace stats-box text for one metric panel.

    Mean/Std/Min/Max are computed on *values* (NaN skipped, outliers kept).
    The last line reports how many points were trimmed from the plot. Every
    line is padded to the same width so the numbers are right-aligned.
    """
    rows = [
        ('Mean:', f'{values.mean():.2f}'),
        ('Std:', f'{values.std():.2f}'),
        ('Min:', f'{values.min():.2f}'),
        ('Max:', f'{values.max():.2f}'),
        ('Filtered:', f'{filtered_count}'),
    ]
    # The widest row gets exactly one space between label and value.
    width = max(len(label) + len(value) for label, value in rows) + 1
    return '\n'.join(
        f"{label}{' ' * (width - len(label) - len(value))}{value}"
        for label, value in rows
    )


def _plot_metric(ax, label, color, values):
    """Draw one panel: boxplot + swarmplot of the trimmed values, plus a stats box.

    Args:
        ax: matplotlib Axes to draw into.
        label: display name of the metric (also used for the title/y-label).
        color: box color for this panel.
        values: pandas Series with the metric for all runs of one chain.
    """
    present = values.dropna()
    lower = present.quantile(_LOWER_QUANTILE)
    upper = present.quantile(_UPPER_QUANTILE)
    # Remove the extreme tails for better visualization.
    trimmed = present[present.between(lower, upper)]
    filtered_count = present.count() - trimmed.count()

    if trimmed.empty:
        # E.g. an all-NaN column, or only two distinct runs (both fall outside
        # the interpolated quantile bounds); there is nothing to draw.
        ax.text(0.5, 0.5, 'No data in range', transform=ax.transAxes,
                horizontalalignment='center', verticalalignment='center')
    else:
        sns.boxplot(data=trimmed, ax=ax, color=color, showfliers=False, width=0.4)  # type: ignore
        # Overlay the individual data points.
        sns.swarmplot(data=trimmed, ax=ax, color='black', size=3, alpha=0.6)  # type: ignore

    ax.set_title(f'{label} Distribution', fontsize=14, fontweight='bold')
    ax.set_xticks([])  # a single distribution per panel; x-ticks add nothing
    ax.set_xlabel('')
    ax.set_ylabel('Latency (ms)' if label != 'Count' else 'Count', fontsize=12)

    # Stats box in the top-right corner, in axes-fraction coordinates.
    ax.text(
        0.95, 0.98,
        _format_stats_text(values, filtered_count),
        transform=ax.transAxes,
        verticalalignment='top',
        horizontalalignment='right',
        fontsize=10,
        fontfamily='monospace',
        bbox=dict(facecolor='white', alpha=0.9, boxstyle='round,pad=0.3', edgecolor='gray'),
    )
    ax.grid(axis='y', linestyle='--', alpha=0.4)


def main():
    """Plot per-chain distributions of run statistics from a CSV file.

    Reads the CSV given by ``--input``. The file must have the columns
    ``experiment_name``, ``chain`` and the statistic columns ``mean``,
    ``std``, ``min``, ``max`` and ``count``. For each distinct ``chain``, a
    PNG with one panel per statistic is saved next to the input file, and
    ``describe()`` output for those columns is printed.

    Raises:
        ValueError: if a required column is missing or the CSV has no rows.
    """
    args = parse_arguments()
    df = pd.read_csv(args.input)

    if 'experiment_name' not in df.columns:
        raise ValueError("Input CSV must contain 'experiment_name' column.")
    missing = [col for col in ['chain', *_METRIC_COLUMNS] if col not in df.columns]
    if missing:
        raise ValueError(f"Input CSV is missing required column(s): {', '.join(missing)}")
    if df.empty:
        raise ValueError("Input CSV contains no rows.")

    # The experiment name is expected to be identical on every row.
    experiment_name = _strip_timestamp(df['experiment_name'].iloc[0])

    # Same four colors for the value panels; a distinct one for 'Count',
    # which is a different kind of metric. Hoisted out of the chain loop.
    colors = sns.color_palette("husl", 4) + ['lightcoral']

    for chain_name, chain_data in df.groupby('chain'):
        fig, axs = plt.subplots(1, len(_METRIC_COLUMNS), figsize=(18, 6), constrained_layout=True)
        try:
            for ax, column, label, color in zip(axs, _METRIC_COLUMNS, _METRIC_LABELS, colors):
                _plot_metric(ax, label, color, chain_data[column])

            fig.suptitle(
                f'Statistics for Chain: {chain_name}\nAcross {len(chain_data)} Experiment Runs - {experiment_name}',
                fontsize=18, fontweight='bold',
            )
            output_file = _output_path(args.input, _sanitize_chain_name(chain_name))
            fig.savefig(output_file, dpi=300)
        finally:
            # Release the figure even if plotting/saving fails.
            plt.close(fig)

        summary = chain_data.describe()
        print(f"\nSummary for chain: {chain_name}")
        print(summary[_METRIC_COLUMNS])

    print(f"\nAnalysis complete. Plots saved with base name: {_output_path(args.input, '*')}")
|
|
|
|
if __name__ == "__main__":
    # Run the analysis only when executed as a script, not on import.
    main()