matplotlib.pyplot
matplotlib.pyplot is one of the best-known Python modules for plotting 2D charts. To draw these charts, the package is organized around the classes Figure, Subplot and Axes. It is usually imported as plt.
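The Figure corresponds to the canvas where the elements are plotted. It can be created through the figure method, and its number and size can be specified through the optional parameters num and figsize (width and height, in inches).
If these parameters are not given, default values are assumed: a new figure with the next figure number, and the size defined in rcParams['figure.figsize'] (6.4 by 4.8 inches, unless changed).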
The last figure created becomes the active one, and any pyplot command is applied to it, unless we call the method directly on a previously created figure.
The gcf method returns a reference to the currently active figure.
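The notion of an active figure can be checked with a small sketch. It assumes the non-interactive Agg backend (selected with matplotlib.use) so that it runs without a display; the figure numbers are arbitrary.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, so no window is needed
import matplotlib.pyplot as plt

fig_a = plt.figure(num=1, figsize=(5, 4))  # figure 1 becomes the active figure
fig_b = plt.figure(num=2)                  # now figure 2 is the active one
assert plt.gcf() is fig_b

# Calling figure() with an existing number reactivates that figure instead of creating a new one
same = plt.figure(num=1)
assert same is fig_a and plt.gcf() is fig_a

# Methods called directly on a figure object act on it, whichever figure is active
fig_b.suptitle("drawn on figure 2")
print(plt.get_fignums())  # → [1, 2]
plt.close("all")
```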
import matplotlib.pyplot as plt
import warnings
warnings.simplefilter("ignore")
plt.figure(num=1, figsize=(5,4))
<Figure size 360x288 with 0 Axes>
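The size reported above is in pixels: figsize is given in inches and multiplied by the resolution (dpi). The 360x288 output is consistent with 5×4 inches at 72 dpi, which is presumably the resolution used by the notebook's inline backend (an assumption; the rcParams default for figure.dpi is 100). A minimal sketch of that arithmetic, again assuming the Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt

# figsize is in inches; pixels = inches * dpi
fig = plt.figure(figsize=(5, 4), dpi=72)
print(fig.get_size_inches())          # size in inches
print(fig.canvas.get_width_height())  # size in pixels: 5*72 by 4*72
plt.close(fig)
```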
Plotting data in the figure is done through the plot method, passing it the data to plot.
After plotting the data, we just need to invoke the show method to display the figure.
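import pandas as pd
from pandas.plotting import register_matplotlib_converters
import matplotlib.dates as mdates
# Restore matplotlib's pre-3.3 date epoch; _reset_epoch_test_example is a private
# helper that allows set_epoch to be called even after dates have been converted
mdates._reset_epoch_test_example()
mdates.set_epoch('0000-12-31T00:00:00') # old epoch (pre MPL 3.3)
register_matplotlib_converters()  # let matplotlib handle pandas dates
# The 'date' column becomes the index, parsed as dates
data = pd.read_csv('data/algae.csv', index_col='date', parse_dates=True)
plt.figure(figsize=(12,4))
plt.plot(data['pH'])
plt.show()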
In our case, we plot the pH values recorded over time in the algae dataset. As we can see, the figure shows pH values between 5 and 10, recorded from 2018-09-30 to 2019-09-17. By default, the index of the data (the date, in our example) provides the values along the x-axis (abscissa), and the pH values are placed along the y-axis (ordinate).
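Since data/algae.csv may not be at hand, the same behaviour can be sketched with a synthetic series. The 30 daily values below are hypothetical and only stand in for the pH column; the point is that passing a single pandas Series to plot puts its index (dates) on the x-axis and its values on the y-axis. The Agg backend is assumed so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical stand-in for the pH column: 30 daily readings between 5 and 10
index = pd.date_range("2018-09-30", periods=30, freq="D")
ph = pd.Series(np.linspace(5, 10, 30), index=index, name="pH")

plt.figure(figsize=(12, 4))
lines = plt.plot(ph)  # index -> x-axis, values -> y-axis
# With an interactive backend, plt.show() would now display the figure
plt.close("all")
```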
To change the interval shown along an axis, we can invoke the xlim and ylim methods, for the x- and y-axis respectively, giving them the lower and upper limits of the interval. It is also possible to add a title to the plot and labels to the axes, as below.
plt.figure(figsize=(12,4))
plt.ylim(0, 14)
plt.title('pH along time')
plt.xlabel('date')
plt.ylabel('pH')
plt.plot(data['pH'])
plt.show()
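The pyplot calls above do not belong to the figure itself: they configure its current Axes, which can be retrieved with gca and inspected. A small sketch with hypothetical values (and the Agg backend assumed) confirms this:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt

values = [6.8, 7.4, 8.1, 7.9, 8.5]  # hypothetical values, only to have something on the chart

plt.figure(figsize=(12, 4))
plt.plot(values)
plt.ylim(0, 14)                # lower and upper limits of the y-axis
plt.xlim(0, len(values) - 1)   # lower and upper limits of the x-axis
plt.title('pH along time')
plt.xlabel('sample')
plt.ylabel('pH')

# The pyplot calls configured the current Axes, which we can query directly
ax = plt.gca()
print(ax.get_ylim(), ax.get_title())
plt.close('all')
```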
Naturally, we may want to plot more than one chart in a figure. To do that, we can split the figure with the subplots method.
This method receives the number of rows and columns into which the figure is split, and the optional parameters sharex and sharey, which specify whether the subplots share the x-axis and the y-axis, respectively.
subplots returns the split figure and an array of Axes, one for each part of the figure. With several rows and several columns the array is two-dimensional; with a single row or a single column it is one-dimensional, and for a 1-by-1 grid a single Axes is returned instead of an array. An Axes is the class that encompasses most of the elements in a figure, such as the title and the legend, but also the usual elements of a chart, like the coordinate system, its labels, units, ticks, etc.
Consequently, to plot different charts in a single figure, we have to call the methods seen before through each Axes object, using their set_ counterparts: set_title, set_xlabel, set_ylabel, set_xlim and set_ylim.
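A minimal sketch of both points, with hypothetical data and the Agg backend assumed: the shape of the array returned by subplots, and each Axes configured through its own set_ methods. With sharey=True, setting the limits on one subplot applies them to all of them.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt

# A 2x3 grid: axs is a 2-D NumPy array of Axes, indexed as axs[row, col]
fig, axs = plt.subplots(2, 3, figsize=(9, 5), sharex=True, sharey=True)
print(axs.shape)  # (2, 3)

for row in range(2):
    for col in range(3):
        ax = axs[row, col]
        ax.plot([0, 1, 2], [row, col, row + col])  # hypothetical data
        ax.set_title(f'row {row}, col {col}')      # instead of plt.title
        ax.set_xlabel('x')
        ax.set_ylabel('y')
axs[0, 0].set_ylim(0, 4)  # shared y-axis: the other subplots get the same limits

# A single column is squeezed into a 1-D array, so axs[n] is enough
fig2, col_axs = plt.subplots(3, 1)
print(col_axs.shape)  # (3,)
plt.close('all')
```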
To make this configuration easier, let's define some auxiliary functions, so that it is written just once.
from matplotlib.font_manager import FontProperties

FONT_TEXT = FontProperties(size=6)
TEXT_MARGIN = 0.05
NR_COLUMNS: int = 3
HEIGHT: int = 4
WIDTH_PER_VARIABLE: float = 0.5

def choose_grid(nr):
    # At most NR_COLUMNS charts per row, and as many rows as needed for nr charts
    if nr < NR_COLUMNS:
        return 1, nr
    else:
        return (nr // NR_COLUMNS, NR_COLUMNS) if nr % NR_COLUMNS == 0 else (nr // NR_COLUMNS + 1, NR_COLUMNS)

def set_elements(ax: plt.Axes = None, title: str = '', xlabel: str = '', ylabel: str = '', percentage: bool = False):
    # Use the current axes when none is given
    if ax is None:
        ax = plt.gca()
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if percentage:
        # Percentages are shown as fractions between 0 and 1
        ax.set_ylim(0.0, 1.0)
    return ax
from numpy import arange
from datetime import datetime
from matplotlib.dates import AutoDateLocator, AutoDateFormatter

def set_locators(xvalues: list, ax: plt.Axes = None, rotation: bool = False):
    if ax is None:
        ax = plt.gca()
    if isinstance(xvalues[0], datetime):
        # Dates: matplotlib chooses the tick positions, shown as year-month-day
        locator = AutoDateLocator()
        ax.xaxis.set_major_locator(locator)
        ax.xaxis.set_major_formatter(AutoDateFormatter(locator, defaultfmt='%Y-%m-%d'))
    elif isinstance(xvalues[0], str):
        # Categories: one tick per value, optionally rotated to avoid overlapping labels
        ax.set_xticks(arange(len(xvalues)))
        if rotation:
            ax.set_xticklabels(xvalues, rotation=90, fontsize='small', ha='center')
        else:
            ax.set_xticklabels(xvalues, fontsize='small', ha='center')
    else:
        # Numbers: one tick per value, spanning exactly the given range
        ax.set_xlim(xvalues[0], xvalues[-1])
        ax.set_xticks(xvalues)
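The first one, choose_grid, determines the number of rows and columns of the grid used to show a set of charts, as a function of the number of charts to show (at most NR_COLUMNS charts per row). The second, set_elements, configures the axes, defining their title and labels and, for percentages, limiting the y-axis to the interval [0, 1]. Finally, the third one, set_locators, defines the ticks on the x-axis, treating dates, categorical (string) values and numbers differently. Note the use of gca in set_elements: it returns the current axes, which is used when no axes is passed as a parameter to the function.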
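The grid computed by choose_grid is meant to be passed to subplots. A sketch of that combination follows; grid_for is a hypothetical compact equivalent of choose_grid (rounding up nr / NR_COLUMNS gives the same number of rows), the variable names are hypothetical, and the Agg backend is assumed. squeeze=False keeps axs two-dimensional even when there is a single row, and the cells left empty are hidden.

```python
import math
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt

NR_COLUMNS = 3
HEIGHT = 4

def grid_for(nr: int) -> tuple:
    # Hypothetical compact equivalent of choose_grid: at most NR_COLUMNS columns
    if nr < NR_COLUMNS:
        return 1, nr
    return math.ceil(nr / NR_COLUMNS), NR_COLUMNS

variables = ['a', 'b', 'c', 'd', 'e']  # hypothetical variable names
rows, cols = grid_for(len(variables))  # 5 charts -> 2 rows, 3 columns
fig, axs = plt.subplots(rows, cols, figsize=(cols * HEIGHT, rows * HEIGHT), squeeze=False)

for n, name in enumerate(variables):
    axs[n // cols, n % cols].set_title(name)  # fill the grid row by row
for n in range(len(variables), rows * cols):
    axs[n // cols, n % cols].set_visible(False)  # hide the cells left empty
plt.close(fig)
```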
With these functions, it is now simple to define functions to plot the usual charts in data science. The first of these plots a line chart.
The config module (a config.py file in the project, not shown here) holds some configuration parameters, such as colors.
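The functions below only need three names from that module: LINE_COLOR, FILL_COLOR and ACTIVE_COLORS. A hypothetical config.py could look like this; the color values are assumptions, not the author's.

```python
# config.py: hypothetical contents, the original file is not part of this text
import matplotlib.colors as mcolors

LINE_COLOR = '#1f77b4'  # outline of lines and bars
FILL_COLOR = '#aec7e8'  # fill of single-series bars
ACTIVE_COLORS = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']  # one color per series

# Every entry must be a color specification that matplotlib accepts
assert all(mcolors.is_color_like(c) for c in [LINE_COLOR, FILL_COLOR, *ACTIVE_COLORS])
```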
import config as cfg

def plot_line(xvalues: list, yvalues: list, ax: plt.Axes = None, title: str = '', xlabel: str = '', ylabel: str = '',
              percentage: bool = False, rotation: bool = False):
    ax = set_elements(ax=ax, title=title, xlabel=xlabel, ylabel=ylabel, percentage=percentage)
    set_locators(xvalues, ax=ax, rotation=rotation)
    ax.plot(xvalues, yvalues, c=cfg.LINE_COLOR)
# Keep only the numeric variables: one chart (row) per variable
numeric_data = data.select_dtypes(include='number')
rows = numeric_data.shape[1]
fig, axs = plt.subplots(rows, 1, figsize=(5*HEIGHT, rows*HEIGHT))
for n, col in enumerate(numeric_data):
    plot_line(numeric_data.index, numeric_data[col], ax=axs[n], title=col, xlabel='date', ylabel=col)
A similar approach is used to plot several series in a single chart, as our function multiple_line_chart exemplifies. Note that the series must have the same index, and should have values in similar ranges.
All the series in a dataframe satisfy the first constraint, and in our case study Phosphate and Orthophosphate satisfy the second too.
def multiple_line_chart(xvalues: list, yvalues: dict, ax: plt.Axes = None, title: str = '', xlabel: str = '',
                        ylabel: str = '', percentage: bool = False, rotation: bool = False):
    ax = set_elements(ax=ax, title=title, xlabel=xlabel, ylabel=ylabel, percentage=percentage)
    set_locators(xvalues, ax=ax, rotation=rotation)
    legend: list = []
    for name, y in yvalues.items():
        ax.plot(xvalues, y)
        legend.append(name)
    ax.legend(legend)
two_series = {'Phosphate': data['Phosphate'], 'Orthophosphate': data['Orthophosphate']}
plt.figure(figsize=(12,4))
multiple_line_chart(data.index, two_series, title='Phosphate and Orthophosphate values', xlabel='date')
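The same pattern can be sketched without the algae data, with two hypothetical series that share a date index (the Agg backend is assumed). Here each line receives its name through the label argument of plot, so ax.legend() collects the names by itself; this is an alternative to building the legend list by hand, with the same result.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical series sharing the same date index, with values in similar ranges
index = pd.date_range('2019-01-01', periods=4, freq='D')
series = {
    'Phosphate': pd.Series([3.0, 2.5, 4.1, 3.6], index=index),
    'Orthophosphate': pd.Series([2.1, 1.9, 3.0, 2.8], index=index),
}

fig, ax = plt.subplots(figsize=(12, 4))
for name, y in series.items():
    ax.plot(y.index, y.values, label=name)  # same x values for every series
legend = ax.legend()  # uses the labels given to plot
plt.close(fig)
```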
Bar charts are not so different from line charts; indeed, the functions for plotting them are very similar to the previous ones.
In the next example, the function bar_chart is called to plot the frequency of each value of the 'season' variable in our dataset.
def bar_chart(xvalues: list, yvalues: list, ax: plt.Axes = None, title: str = '', xlabel: str = '', ylabel: str = '',
              percentage: bool = False, rotation: bool = False):
    ax = set_elements(ax=ax, title=title, xlabel=xlabel, ylabel=ylabel, percentage=percentage)
    set_locators(xvalues, ax=ax, rotation=rotation)
    ax.bar(xvalues, yvalues, edgecolor=cfg.LINE_COLOR, color=cfg.FILL_COLOR, tick_label=xvalues)
    # Write each bar's value just above it
    for i in range(len(yvalues)):
        ax.text(i, yvalues[i] + TEXT_MARGIN, f'{yvalues[i]:.2f}', ha='center', fontproperties=FONT_TEXT)
plt.figure()
counts = data['season'].value_counts()
bar_chart(counts.index, counts.values, title='season distribution', xlabel='season', ylabel='frequency', rotation=True)
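To see what value_counts hands to the bar chart, here is a sketch with a hypothetical stand-in for the 'season' column (Agg backend assumed). It also uses Axes.bar_label, available since matplotlib 3.4, as an alternative to the manual ax.text loop for writing each bar's value.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical stand-in for the 'season' column
season = pd.Series(['winter', 'winter', 'winter', 'spring', 'spring', 'summer'])
counts = season.value_counts()  # frequency of each value, most frequent first
print(counts.index.tolist(), counts.values.tolist())  # → ['winter', 'spring', 'summer'] [3, 2, 1]

fig, ax = plt.subplots()
bars = ax.bar(counts.index, counts.values)  # string x values become categorical positions
texts = ax.bar_label(bars, fmt='%.2f')      # writes each bar's height above it
plt.close(fig)
```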
Similarly, multiple_bar_chart plots a grouped bar chart, with each series corresponding to an entry in the yvalues dictionary.
In our example, the frequencies of the fluid_velocity and river_depth values are plotted together, since both variables take the same values (high, low and medium).
def multiple_bar_chart(xvalues: list, yvalues: dict, ax: plt.Axes = None, title: str = '', xlabel: str = '', ylabel: str = '',
                       percentage: bool = False):
    ax = set_elements(ax=ax, title=title, xlabel=xlabel, ylabel=ylabel, percentage=percentage)
    ngroups = len(xvalues)
    nseries = len(yvalues)
    pos_group = arange(ngroups)
    series_width = 1  # each group takes one unit on the x-axis
    width = series_width / nseries - 0.1 * series_width  # bar width, leaving a small gap
    pos_center = pos_group + nseries * width / 2  # middle of each group, where the tick labels go
    pos_group = pos_group + width / 2  # center of the first bar of each group
    legend = []
    for i, metric in enumerate(yvalues):
        values = list(yvalues[metric])  # positional access, whatever the Series index is
        ax.bar(pos_group, values, width=width, edgecolor=cfg.LINE_COLOR, color=cfg.ACTIVE_COLORS[i])
        legend.append(metric)
        for k in range(len(values)):
            ax.text(pos_group[k], values[k] + TEXT_MARGIN, f'{values[k]:.2f}', ha='center', fontproperties=FONT_TEXT)
        pos_group = pos_group + width  # the next series goes right after this one
    ax.legend(legend, fontsize='x-small', title_fontsize='small')
    ax.set_xticks(pos_center)
    ax.set_xticklabels(xvalues)
two_series = {'river_depth': data['river_depth'].value_counts().sort_index(),
'fluid_velocity': data['fluid_velocity'].value_counts().sort_index()}
plt.figure()
multiple_bar_chart(['high', 'low', 'medium'], two_series, ylabel='frequency')
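The layout arithmetic behind a grouped bar chart can be checked on its own. In this sketch, grouped_positions is a hypothetical helper: each group occupies one unit on the x-axis, each of the n bars has width w = 1/n - 0.1, bar i of group k is centered at k + (i + 1/2)w, and the group's label goes at its middle, k + nw/2.

```python
import numpy as np

def grouped_positions(ngroups: int, nseries: int, gap: float = 0.1):
    # Hypothetical helper: one unit per group, bars shrunk by a small gap
    width = 1 / nseries - gap
    start = np.arange(ngroups) + width / 2                # center of the first bar in each group
    bars = [start + i * width for i in range(nseries)]    # series i is shifted by i bar widths
    centers = np.arange(ngroups) + nseries * width / 2    # middle of each group, for the tick labels
    return width, bars, centers

# Three groups (e.g. high, low, medium) and two series
width, bars, centers = grouped_positions(ngroups=3, nseries=2)
print(width, bars[0], bars[1], centers)
```

With two series, the bars of each group touch each other and the label sits exactly between them; the remaining 0.2 units separate consecutive groups.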