Goal: Master Matplotlib, Python’s foundational plotting library, to create insightful visualizations for data exploration and presentation in AI/ML workflows. Learn to generate line plots, bar charts, histograms, scatter plots, and customize them for clarity.
1. What is Matplotlib?
Matplotlib is a comprehensive Python library for creating static, animated, and interactive visualizations. Its pyplot
module provides a MATLAB-like interface for quick plotting, while its object-oriented API offers fine-grained control.
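To make the difference between the two interfaces concrete, here is a minimal sketch that draws the same sine curve both ways. The variable names (implicit_ax, fig, ax) are illustrative choices, not required names.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 100)

# pyplot interface: functions act on an implicit "current" figure and axes.
plt.figure()
plt.plot(x, np.sin(x))
plt.title("pyplot interface")
implicit_ax = plt.gca()  # the Axes that pyplot created behind the scenes

# Object-oriented interface: you hold the Figure and Axes and call methods on them.
fig, ax = plt.subplots()
ax.plot(x, np.sin(x))
ax.set_title("object-oriented interface")
```

Call plt.show() afterwards to display both figures. The pyplot style is convenient for quick, single-plot scripts; the object-oriented style tends to be easier to follow once a figure has several subplots.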
Key Components:
Figure: The top-level container for all plot elements.
Axes: The area where data is plotted (a single plot or subplot).
Axis: The x/y-axis with ticks, labels, and limits.
Artist: Everything visible on the figure (lines, text, legends).
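The four components above can be inspected directly in code. The following is a small sketch with made-up data that creates one of each and shows how they relate.

```python
import matplotlib.pyplot as plt
from matplotlib.artist import Artist

# Figure: the top-level container. Axes: the plotting area inside it.
fig, ax = plt.subplots(figsize=(6, 3))

# plot() returns a list of Line2D objects; each one is an Artist.
(line,) = ax.plot([0, 1, 2], [0, 1, 4], label="y = x^2")

# Axis objects hang off the Axes and own the ticks, labels, and limits.
ax.xaxis.set_label_text("x")
ax.yaxis.set_label_text("y")
ax.set_xlim(0, 2)

# set_title() returns a Text object, which is also an Artist.
title = ax.set_title("Anatomy of a figure")
```

Note the naming trap: Axes (the plot area) and Axis (one of its x/y axes) are different objects.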
2. Basic Plots
Line Plot
import matplotlib.pyplot as plt
import numpy as np
# Generate data
x = np.linspace(0, 10, 100)
y = np.sin(x)
# Create plot
plt.figure(figsize=(8, 4))
plt.plot(x, y, label="sin(x)", color="blue", linestyle="--", linewidth=2)
plt.title("Sine Wave")
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.grid(True)
plt.legend()
plt.show()
Bar Chart
categories = ["A", "B", "C"]
values = [25, 40, 30]
plt.bar(categories, values, color=["red", "green", "blue"])
plt.title("Category Performance")
plt.xlabel("Category")
plt.ylabel("Value")
plt.show()
Histogram
data = np.random.randn(1000) # 1,000 samples from a standard normal distribution
plt.hist(data, bins=30, edgecolor="black", alpha=0.7)
plt.title("Data Distribution")
plt.xlabel("Value")
plt.ylabel("Frequency")
plt.show()
Scatter Plot
x = np.random.rand(50)
y = x + np.random.randn(50) * 0.1
plt.scatter(x, y, color="purple", alpha=0.5, marker="o")
plt.title("X vs Y Correlation")
plt.xlabel("X")
plt.ylabel("Y")
plt.show()
3. Customizing Plots
Styling
Colors: Use named colors ("red"), hex codes ("#FF5733"), or RGB tuples.
Markers: o, s, ^, D, etc.
Line Styles: -, --, -., :.
plt.plot(x, y, color="#FF5733", linestyle="-.", marker="o", markersize=5)
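As a further illustration, the sketch below uses each of the three color formats together with different markers and line styles. The curves and labels are arbitrary examples; to_hex from matplotlib.colors is used only to show how an RGB tuple maps to a hex code.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import to_hex

x = np.linspace(0, 2 * np.pi, 20)

fig, ax = plt.subplots()
# Named color, square markers, dotted line.
ax.plot(x, np.sin(x), color="red", marker="s", linestyle=":", label="named")
# Hex code, triangle markers, dashed line.
ax.plot(x, np.cos(x), color="#FF5733", marker="^", linestyle="--", label="hex")
# RGB tuple (each channel in the 0-1 range), diamond markers, solid line.
(rgb_line,) = ax.plot(
    x, 0.5 * np.sin(x), color=(0.2, 0.4, 0.8), marker="D", linestyle="-", label="RGB tuple"
)
ax.legend()

# to_hex converts a Matplotlib color specification to a hex string.
rgb_as_hex = to_hex((0.2, 0.4, 0.8))
```

RGB tuples in Matplotlib use floats between 0 and 1, not integers between 0 and 255.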
Labels & Legends
plt.title("Customized Plot", fontsize=14, fontweight="bold")
plt.xlabel("X-Axis", fontsize=12)
plt.ylabel("Y-Axis", fontsize=12)
plt.legend(loc="upper right", fontsize=10) # Requires label="..." on at least one plotted artist
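The same labelling calls have object-oriented equivalents. The sketch below, using two made-up curves, shows them together with the label arguments that the legend reads its entries from.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 5, 50)

fig, ax = plt.subplots()
# The legend text comes from each artist's label argument.
ax.plot(x, x, label="linear")
ax.plot(x, x**2, label="quadratic")

ax.set_title("Customized Plot", fontsize=14, fontweight="bold")
ax.set_xlabel("X-Axis", fontsize=12)
ax.set_ylabel("Y-Axis", fontsize=12)
legend = ax.legend(loc="upper left", fontsize=10)

# Collect the legend entries to confirm what will be displayed.
legend_labels = [text.get_text() for text in legend.get_texts()]
```

If no plotted artist carries a label, the legend is empty and Matplotlib emits a warning.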
Subplots
fig, axes = plt.subplots(2, 2, figsize=(10, 6)) # 2x2 grid
axes[0, 0].plot(x, y)
axes[0, 1].scatter(x, y)
axes[1, 0].bar(categories, values)
axes[1, 1].hist(data, bins=30)
plt.tight_layout() # Avoid overlapping
plt.show()
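When every panel is built the same way, looping over axes.flat avoids repeating the indexing. This is a sketch under the assumption that each subplot shows one function of a shared x range; the choice of functions is arbitrary.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 100)
functions = {"sin": np.sin, "cos": np.cos, "tanh": np.tanh, "sinc": np.sinc}

# sharex=True links the x-limits of all four panels.
fig, axes = plt.subplots(2, 2, figsize=(8, 6), sharex=True)

# axes is a 2x2 NumPy array; axes.flat iterates over it row by row.
for ax, (name, fn) in zip(axes.flat, functions.items()):
    ax.plot(x, fn(x))
    ax.set_title(name)  # per-panel title, set on the Axes itself

fig.tight_layout()
```

Setting titles through each Axes matters here, because plt.title() only affects the most recently active subplot.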
4. Advanced Features
Annotations
Highlight specific points:
# Assumes x and y hold the sine-wave data from the Line Plot example
plt.plot(x, y)
plt.annotate("Peak", xy=(np.pi/2, 1), xytext=(3, 0.8),
             arrowprops=dict(arrowstyle="->", color="black"))
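Rather than hard-coding the coordinates, the point to annotate can be computed from the data. The following self-contained sketch locates the maximum of a sine curve with np.argmax; the text offsets (1.5 and 0.2) are arbitrary values chosen for readability.

```python
import matplotlib.pyplot as plt
import numpy as np

# One full period, so the curve has a single maximum.
x = np.linspace(0, 2 * np.pi, 1000)
y = np.sin(x)

# Index and coordinates of the highest sample.
peak_idx = int(np.argmax(y))
peak_x, peak_y = x[peak_idx], y[peak_idx]

fig, ax = plt.subplots()
ax.plot(x, y)
annotation = ax.annotate(
    "Peak",
    xy=(peak_x, peak_y),                   # the point being highlighted
    xytext=(peak_x + 1.5, peak_y - 0.2),   # where the label text sits
    arrowprops=dict(arrowstyle="->", color="black"),
)
```

Here xy is the arrow tip and xytext is the label position, both in data coordinates by default.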
Styling with plt.style
Use predefined styles:
plt.style.use("ggplot") # Other options: "dark_background", "seaborn-v0_8" (Matplotlib 3.6+); list all with plt.style.available
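plt.style.use changes the style globally for the rest of the session. If only one figure should be styled, plt.style.context applies a style temporarily, as in this short sketch with made-up data.

```python
import matplotlib.pyplot as plt

# Names of all styles installed with this Matplotlib version.
available_styles = plt.style.available

facecolor_before = plt.rcParams["axes.facecolor"]

# The style applies only inside the with-block.
with plt.style.context("ggplot"):
    fig, ax = plt.subplots()
    ax.plot([1, 2, 3], [2, 4, 3])
    ax.set_title("Styled with ggplot")

# Outside the block, the previous settings are restored.
facecolor_after = plt.rcParams["axes.facecolor"]
```

Checking available_styles first is a safe habit, since the set of style names differs between Matplotlib versions.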
Saving Figures
plt.savefig("plot.png", dpi=300, bbox_inches="tight") # Call before plt.show(); afterwards the saved image may be blank
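Holding a Figure object makes it explicit which figure is being saved, and the file extension selects the output format. The sketch below writes a PNG and an SVG into a temporary directory; the directory and file names are illustrative only.

```python
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(4, 3))
ax.plot([0, 1, 2], [0, 1, 0])
ax.set_title("Saved figure")

# Temporary output directory so the example does not clutter the project.
out_dir = Path(tempfile.mkdtemp())
png_path = out_dir / "plot.png"
svg_path = out_dir / "plot.svg"

# Raster output: dpi controls the resolution.
fig.savefig(png_path, dpi=300, bbox_inches="tight")
# Vector output: scales without loss, so dpi is not needed.
fig.savefig(svg_path, bbox_inches="tight")

# Free the figure's memory once it has been written.
plt.close(fig)
```

Raster formats such as PNG suit slides and web pages, while vector formats such as SVG or PDF are usually preferable for print.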
5. Integrating with Pandas
Plot directly from DataFrames:
import pandas as pd
# Load data
df = pd.read_csv("titanic.csv")
# Bar chart of the mean survival rate per passenger class
survival_rate = df.groupby("Pclass")["Survived"].mean()
survival_rate.plot(kind="bar", color="skyblue")
plt.title("Survival Rate by Class")
plt.xlabel("Class")
plt.ylabel("Survival Rate")
plt.xticks(rotation=0)
plt.show()
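If titanic.csv is not at hand, the same pattern can be tried on a small in-memory DataFrame. The table below is made-up toy data, not real Titanic records; only the column names mirror the example above. Passing ax=ax tells Pandas to draw onto an existing Axes.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Toy data for illustration only (NOT real Titanic records).
df = pd.DataFrame(
    {
        "Pclass": [1, 1, 2, 2, 3, 3, 3, 3],
        "Survived": [1, 1, 1, 0, 0, 0, 1, 0],
    }
)

# Mean of a 0/1 column per group gives the survival rate per class.
survival_rate = df.groupby("Pclass")["Survived"].mean()

fig, ax = plt.subplots()
# Pandas draws onto the Axes we pass in; rot=0 keeps tick labels horizontal.
survival_rate.plot(kind="bar", ax=ax, color="skyblue", rot=0)
ax.set_title("Survival Rate by Class")
ax.set_xlabel("Class")
ax.set_ylabel("Survival Rate")
ax.set_ylim(0, 1)
```

Because the result is an ordinary Axes, every Matplotlib customization shown earlier still applies.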
6. Real-World Use Cases in AI/ML
Feature Analysis:
Plot distributions of input features (e.g., df["Age"].hist()).
Model Evaluation:
Visualize training vs validation loss curves.
Plot confusion matrices or ROC curves.
Clustering:
Scatter plots for 2D/3D cluster visualization.
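The model-evaluation plots mentioned above could look roughly like the following sketch. All numbers are synthetic: the loss curves come from simple formulas and the confusion matrix counts are hypothetical, so no actual model is involved.

```python
import matplotlib.pyplot as plt
import numpy as np

# Synthetic curves for illustration only; no model was trained.
epochs = np.arange(1, 21)
train_loss = 1.0 / epochs                 # keeps decreasing
val_loss = 1.0 / epochs + 0.01 * epochs   # decreases, then rises (overfitting)
best_epoch = int(epochs[np.argmin(val_loss)])

# Hypothetical 2-class confusion matrix: rows = true labels, columns = predictions.
cm = np.array([[50, 5], [8, 37]])
class_names = ["negative", "positive"]

fig, (ax_loss, ax_cm) = plt.subplots(1, 2, figsize=(10, 4))

# Left panel: training vs validation loss.
ax_loss.plot(epochs, train_loss, label="Training loss")
ax_loss.plot(epochs, val_loss, label="Validation loss")
ax_loss.axvline(best_epoch, color="gray", linestyle=":", label="Lowest validation loss")
ax_loss.set_xlabel("Epoch")
ax_loss.set_ylabel("Loss")
ax_loss.legend()

# Right panel: confusion matrix drawn as a heatmap with the counts written in.
image = ax_cm.imshow(cm, cmap="Blues")
ax_cm.set_xticks([0, 1])
ax_cm.set_xticklabels(class_names)
ax_cm.set_yticks([0, 1])
ax_cm.set_yticklabels(class_names)
ax_cm.set_xlabel("Predicted")
ax_cm.set_ylabel("True")
for row in range(cm.shape[0]):
    for col in range(cm.shape[1]):
        ax_cm.text(col, row, str(cm[row, col]), ha="center", va="center")
fig.colorbar(image, ax=ax_cm)

fig.tight_layout()
```

A widening gap between the two loss curves is the usual visual cue for overfitting; here it is built into the synthetic formula.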
7. Best Practices
Avoid Overplotting: Use transparency (alpha) or jittering.
Choose the Right Chart:
Line plots: Trends over time.
Bar charts: Categorical comparisons.
Scatter plots: Relationships between variables.
Label Clearly: Always include titles, axis labels, and legends.
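The overplotting advice is easiest to appreciate side by side. This sketch uses synthetic discrete ratings, where many points share identical coordinates, and compares a plain scatter plot with one that uses jitter and transparency. The jitter width of 0.2 and alpha of 0.2 are arbitrary choices.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(42)

# Synthetic data: two discrete 1-5 ratings, so many points coincide exactly.
rating_a = rng.integers(1, 6, size=2000)
rating_b = np.clip(rating_a + rng.integers(-1, 2, size=2000), 1, 5)

# Jitter: small uniform noise that spreads coincident points apart.
jitter_a = rng.uniform(-0.2, 0.2, size=rating_a.shape)
jitter_b = rng.uniform(-0.2, 0.2, size=rating_b.shape)

fig, (ax_raw, ax_fixed) = plt.subplots(1, 2, figsize=(9, 4), sharex=True, sharey=True)

# Left: 2,000 points collapse into a handful of visible dots.
ax_raw.scatter(rating_a, rating_b)
ax_raw.set_title("Overplotted")

# Right: jitter plus transparency reveals where the data is dense.
ax_fixed.scatter(rating_a + jitter_a, rating_b + jitter_b, alpha=0.2, s=10)
ax_fixed.set_title("Jitter + alpha")

for ax in (ax_raw, ax_fixed):
    ax.set_xlabel("Rating A")
ax_raw.set_ylabel("Rating B")
fig.tight_layout()
```

Jitter changes the plotted positions, so it is worth mentioning in a caption whenever the exact values matter to the reader.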
8. Practice Exercise
Load the Titanic dataset (pd.read_csv("titanic.csv")).
Plot a histogram of passenger ages.
Create a bar chart showing survival rates by gender.
Generate a scatter plot of Age vs Fare, colored by survival.
Solution:
# 1. Load data
titanic = pd.read_csv("titanic.csv")
# 2. Age histogram
plt.hist(titanic["Age"].dropna(), bins=20, edgecolor="black")
plt.title("Age Distribution")
plt.xlabel("Age")
plt.ylabel("Count")
plt.show()
# 3. Survival by gender
survival_gender = titanic.groupby("Sex")["Survived"].mean()
survival_gender.plot(kind="bar", color=["pink", "blue"])
plt.title("Survival Rate by Gender")
plt.ylabel("Survival Rate")
plt.xticks(rotation=0)
plt.show()
# 4. Age vs Fare scatter plot
plt.scatter(titanic["Age"], titanic["Fare"], c=titanic["Survived"], alpha=0.5)
plt.colorbar(label="Survived")
plt.title("Age vs Fare (Colored by Survival)")
plt.xlabel("Age")
plt.ylabel("Fare")
plt.show()
Key Takeaways
Matplotlib provides granular control over every plot element.
Use Pandas integration to streamline visualization from DataFrames.
Prioritize clarity and simplicity to communicate insights effectively.