Top 5 Useful Graphs in Matplotlib
In data science, data visualization is a step that occurs throughout a project: in the research and analysis phase as well as in the final presentation to the client. It is therefore essential to know the most useful visualization techniques and how to interpret the resulting graphs. So I invite you to discover 5 useful graphs using Matplotlib!
1. Scatter
Let's start with the basics: scatter plots. These graphs are the easiest to create and to interpret. They are very useful for detecting a linear or polynomial trend between 2 variables of your dataset. To create them with Matplotlib, I recommend using the plt.scatter function.
The main disadvantage is that it is not easy to observe the relationship between more than 2 variables at the same time, but the code below shows you a little trick to observe the relationship between several variables on a single figure.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris

# loading an iris flower dataset
iris = load_iris()
X = iris.data    # X contains 150 samples, 4 variables
y = iris.target  # y contains 150 samples, 3 classes

# creating a for loop to display all graphics
n = X.shape[1]
plt.figure(figsize=(12, 8))
for i in range(n):
    plt.subplot(n // 2, n // 2, i + 1)
    plt.scatter(X[:, 0], X[:, i], c=y)
    plt.xlabel('0')
    plt.ylabel(i)
    plt.colorbar(ticks=list(np.unique(y)))
plt.show()
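To illustrate how a scatter plot helps reveal a linear trend, here is a minimal sketch that overlays a least-squares line computed with np.polyfit. The synthetic data, the random seed, and the names x_demo and y_demo are assumptions made for this example; they are not part of the iris code above.

```python
import numpy as np
import matplotlib.pyplot as plt

# Synthetic data (an assumption for this sketch): y = 2x + 1 plus Gaussian noise
rng = np.random.default_rng(0)
x_demo = np.linspace(0, 10, 200)
y_demo = 2 * x_demo + 1 + rng.normal(scale=1.0, size=x_demo.size)

# Degree-1 least-squares fit; a higher degree would capture a polynomial trend
slope, intercept = np.polyfit(x_demo, y_demo, deg=1)

fig, ax = plt.subplots()
ax.scatter(x_demo, y_demo, s=10, alpha=0.6, label='samples')
ax.plot(x_demo, slope * x_demo + intercept, color='red', label='linear fit')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.legend()
# Call plt.show() to display the figure when running this as a script
```

If the fitted line follows the cloud of points closely, a linear relationship is a reasonable first hypothesis; if the points curve away from it, try a higher degree.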
2. 3D Graphs
3D Scatter with Matplotlib
3D graphs are also essential in data science and mathematics. Their primary purpose is to extend 2D graphs (like the ones above) to visualize relationships between 3 variables at a time. To create 3D graphs, you start by loading Matplotlib's 3D module (mplot3d) and its Axes3D object. Then you create an axis object with a 3D projection, which makes it possible to draw surfaces, scatter plots, lines, and other 3D graphs.
# IPython magic: opens a Qt5 window to zoom and move your graph!
%matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # Matplotlib 3D module

ax = plt.axes(projection='3d')  # creates a "3D axis" object
# X and y are the iris data and labels loaded in section 1
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=y)
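For readers working outside IPython, here is a self-contained sketch of a 3D scatter plot. It uses random points in place of a real dataset; the names points, labels, and ax3d are hypothetical, and the snippet assumes a recent Matplotlib version in which the '3d' projection is registered automatically.

```python
import numpy as np
import matplotlib.pyplot as plt

# Hypothetical data: 150 random points with 3 features and 3 class labels
rng = np.random.default_rng(0)
points = rng.normal(size=(150, 3))
labels = rng.integers(0, 3, size=150)

fig = plt.figure()
ax3d = fig.add_subplot(projection='3d')  # axes with a 3D projection
ax3d.scatter(points[:, 0], points[:, 1], points[:, 2], c=labels)
ax3d.set_xlabel('feature 0')
ax3d.set_ylabel('feature 1')
ax3d.set_zlabel('feature 2')
# Call plt.show() to display the figure when running this as a script
```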
Warning: 3D graphs are not very useful in Data Science and Machine Learning. Here are 2 arguments to justify this statement:
- When compared to a dataset with hundreds of dimensions, a 3D graph is not much better than a 2D graph in terms of representation. Ideally, we should use 1000-dimensional graphs, but this does not exist.
- 2D graphs are easier to interpret than 3D graphs. Besides, we often come back to 2D projections when we use 3D graphs (these are contour plots).
However, I would give 3D graphs one special purpose: impressing your colleagues when you have to represent mathematical functions as surfaces.
3D Surfaceplot
3D graphics have an important part to play in mathematics and physics. They are very useful to represent surfaces and functions with 2 parameters.
Imagine trying to observe the surface generated by the function f(x, y) = sin(x) + cos(y) on the domain x, y in [0, 5]. To do this, we need to compute each value of f(x, y) on a 2-dimensional grid (X, Y). This 2D grid is commonly called a meshgrid in engineering and mathematics. To create this grid, we pass the values of x and y to the NumPy function np.meshgrid(x, y). Then we can plot the surface f(X, Y) as a function of X and Y with the ax.plot_surface method.
# Creating a function
f = lambda x, y: np.sin(x) + np.cos(y)

# Creating a domain X, Y
x = np.linspace(0, 5, 100)
y = np.linspace(0, 5, 100)
X, Y = np.meshgrid(x, y)  # Result: 2D numpy array
Z = f(X, Y)

# Displaying the surface
ax = plt.axes(projection='3d')
ax.plot_surface(X, Y, Z, cmap='plasma')
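To see what np.meshgrid actually returns, it can help to run it on a grid small enough to print. The tiny arrays below are invented for this illustration.

```python
import numpy as np

x_small = np.array([0, 1, 2])   # 3 values along x
y_small = np.array([10, 20])    # 2 values along y
X_small, Y_small = np.meshgrid(x_small, y_small)

print(X_small)
# [[0 1 2]
#  [0 1 2]]
print(Y_small)
# [[10 10 10]
#  [20 20 20]]

# Any element-wise function of the two grids is evaluated at every (x, y) pair
Z_small = X_small + Y_small
print(Z_small.shape)  # (2, 3): one row per y value, one column per x value
```

Each position [i, j] of the two arrays holds one (x, y) pair, which is why f(X, Y) yields the value of the function at every point of the grid in a single call.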
3. Matplotlib Histograms
The histogram is the ideal graph to visualize how your data are spread along a quantitative variable. In statistics, this is called visualizing a distribution. The principle is simple: you create several intervals between the minimum and maximum of your data and then count the number of samples that fall within each interval. For example, here is a histogram showing the number (or frequency) of individuals in different age categories:
In any data analysis, we start by using histograms to visualize the distribution of our data: What is the weight of our individuals? What is the price of the apartments in the dataset? Very often, we get a histogram with a Gaussian bell shape, a sign that our data are normally distributed, but other types of distributions are not uncommon (uniform or logarithmic distribution).
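The counting principle behind a histogram can be sketched with np.histogram, which returns the counts and the interval edges without drawing anything. The ages below are invented for this example.

```python
import numpy as np

# Hypothetical ages, invented for this example
ages = np.array([3, 8, 15, 22, 25, 31, 38, 44, 47, 59])

# 3 equal-width intervals between the minimum (3) and the maximum (59)
counts, edges = np.histogram(ages, bins=3)

print(counts)      # [3 4 3]
print(len(edges))  # 4 edges delimit 3 intervals
```

The heights of the bars in a histogram are exactly these counts.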
To create a histogram with Matplotlib, you use the function plt.hist(). This function expects a one-dimensional NumPy array, so make sure you select only one column of your dataset, or use the ndarray.ravel() method to "flatten" your NumPy arrays.
To draw a nice histogram, it is important to define the number of intervals to use. To do this, you must use the bins argument. I advise you to test different numbers of intervals. Here is an example of use:
# Generating data: 1000 normal random points
x = np.random.randn(1000)

# visualizing over 50 intervals
plt.hist(x, bins=50)
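As a sketch of the advice to test several numbers of intervals, the snippet below flattens a 2D array with ravel() and draws the same data with three values of bins. The array data_2d and the chosen bin counts are assumptions made for this example.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
data_2d = rng.normal(size=(100, 10))  # hypothetical 2D dataset
flat = data_2d.ravel()                # flattened to 1000 values

fig, axes = plt.subplots(1, 3, figsize=(12, 3))
all_counts = []
for ax, n_bins in zip(axes, [5, 30, 200]):
    # hist returns the counts, the bin edges, and the drawn bars
    counts, edges, _ = ax.hist(flat, bins=n_bins)
    all_counts.append(counts)
    ax.set_title(f'bins={n_bins}')
# Call plt.show() to display the figure when running this as a script
```

Too few intervals hide the shape of the distribution, while too many make the plot noisy; a value in between usually reads best.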
4. Matplotlib Contour Plots
Contour plots are useful and simple to interpret. They represent a top view of 3D graphics. Similar to a relief map, where the different levels of altitude are represented by contours, they are used to visualize mathematical functions and find a minimum or maximum by eye.
These graphs are built on the same principle as a 3D graph: you create 2D arrays for the X and Y axes with the meshgrid function, then pass these arrays to the plt.contour() or plt.contourf() function. The levels argument lets you draw the contours with more or fewer altitude levels.
# f is the function defined in the 3D surface example above
X = np.linspace(0, 5, 100)
Y = np.linspace(0, 5, 100)
X, Y = np.meshgrid(X, Y)
Z = f(X, Y)
plt.contour(X, Y, Z, levels=40)
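To complement reading an extremum by eye, here is a minimal sketch that draws filled contours with contourf and then locates the grid maximum numerically with np.argmax. The variable names are hypothetical, and the function is the same sin(x) + cos(y) used in this article.

```python
import numpy as np
import matplotlib.pyplot as plt

x_axis = np.linspace(0, 5, 100)
y_axis = np.linspace(0, 5, 100)
X_grid, Y_grid = np.meshgrid(x_axis, y_axis)
Z_grid = np.sin(X_grid) + np.cos(Y_grid)

fig, ax = plt.subplots()
filled = ax.contourf(X_grid, Y_grid, Z_grid, levels=20, cmap='plasma')
fig.colorbar(filled, ax=ax)

# Locate the maximum on the grid (row index follows y, column index follows x)
row, col = np.unravel_index(np.argmax(Z_grid), Z_grid.shape)
x_max, y_max = X_grid[row, col], Y_grid[row, col]
ax.scatter(x_max, y_max, color='white', marker='x')
# Call plt.show() to display the figure when running this as a script
```

On this domain the marker lands near x = pi/2 and y = 0, where sin(x) and cos(y) are both at their largest.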
5. Matplotlib Imshow
Finally, Matplotlib's imshow() function is among the most useful in all of data science. It simply displays a 2D matrix in the same way as the pixels of a 2D image. We can thus visualize confusion matrices, sparse matrices, images, or even contour plots with imshow()! To use this function, you just have to provide it with a 2D NumPy array.
plt.figure(figsize=(12, 3))

# Simple imshow() graph
X = np.random.randn(50, 50)
plt.subplot(131)
plt.imshow(X)

# Iris correlation matrix
from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target
plt.subplot(132)
plt.imshow(np.corrcoef(X.T, y))

# Matrix: f(X, Y) = sin(X) + cos(Y)
X = np.linspace(0, 5, 100)
Y = np.linspace(0, 5, 100)
X, Y = np.meshgrid(X, Y)
f = lambda X, Y: np.sin(X) + np.cos(Y)
plt.subplot(133)
plt.imshow(f(X, Y))
plt.colorbar()
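One detail worth knowing when displaying a function with imshow(): by default the first row of the matrix is drawn at the top, and the axes show pixel indices rather than x and y values. The sketch below, with hypothetical variable names, uses the origin and extent arguments so that the picture matches the mathematical axes.

```python
import numpy as np
import matplotlib.pyplot as plt

x_axis = np.linspace(0, 5, 100)
y_axis = np.linspace(0, 5, 100)
X_grid, Y_grid = np.meshgrid(x_axis, y_axis)
Z_grid = np.sin(X_grid) + np.cos(Y_grid)

fig, ax = plt.subplots()
# origin='lower' puts row 0 at the bottom; extent maps pixel indices to [0, 5]
image = ax.imshow(Z_grid, origin='lower', extent=[0, 5, 0, 5], cmap='plasma')
fig.colorbar(image, ax=ax)
ax.set_xlabel('x')
ax.set_ylabel('y')
# Call plt.show() to display the figure when running this as a script
```

With these two arguments, the image has the same orientation and axis values as the contour plot of the same function.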
Review
Congratulations! You now know 5 very useful Matplotlib graphs for data science. I advise you to use histograms and scatter plots in the first phase of each machine learning project. If you are working with a Pandas DataFrame, I strongly recommend learning Seaborn, a particularly simple and efficient library built on top of Matplotlib. 3D graphs, imshow, and contour plots are more useful for science and engineering.