Being able to visualise data is an essential skill for anyone who works with data. Python has a number of third-party modules you can use to visualise data. Matplotlib and its submodule pyplot are among the most widely used; pyplot is usually imported under the alias plt. Matplotlib provides a highly flexible function, plt.scatter(), which lets you create scatter plots ranging from simple to advanced.
In this guide, you'll work through a number of examples that show how to use this function effectively. By the end of this guide, you'll know how to:
- Use the required and optional input parameters of plt.scatter()
- Customize scatter plots for basic and more advanced plots
- Represent more than two dimensions on a scatter plot
To get the most out of this guide, you should be familiar with the fundamentals of Python programming and with NumPy and its ndarray object. You don't need any prior experience with Matplotlib to follow along. However, if you'd like to learn more about the module, check out Python Plotting Using Matplotlib (Guide).
Creating Scatter Plots
A scatter plot is a visual representation of how two variables relate to each other. You can use scatter plots to explore the relationship between two variables, for example by looking for a correlation between them.
In this section of the guide, you'll learn how to create basic scatter plots using Matplotlib. In later sections, you'll learn how to customise your plots further to represent more complex data using more than two dimensions.
Getting Started With plt.scatter()
Before you can start working with plt.scatter(), you'll need to install Matplotlib. You can do this with pip, Python's standard package manager, by running the following command in the terminal:
```console
$ python -m pip install matplotlib
```
Once Matplotlib is installed, consider the following use case. A café sells six different types of bottled orange drinks. The owner wants to understand the relationship between the price of the drinks and how many of each one he sells, so he keeps track of how many of each drink he sells every day. You can visualise this relationship as follows:
```python
import matplotlib.pyplot as plt

price = [2.50, 1.23, 4.02, 3.25, 5.00, 4.40]
sales_per_day = [34, 62, 49, 22, 13, 19]

plt.scatter(price, sales_per_day)
plt.show()
```
In this Python script, you import the pyplot submodule from Matplotlib using the alias plt. By convention, this alias shortens the module and submodule names. You then create lists with the price and the average daily sales for each of the six orange drinks.
Finally, you create the scatter plot by calling plt.scatter() with the two variables you want to compare as input arguments. Because you're running a Python script, you also need to call plt.show() to display the figure.
When you're working in an interactive environment, such as a console or a Jupyter Notebook, you don't need to call plt.show(). In this guide, all the examples are written as scripts and include a call to plt.show().
Here's the output from this code:
This plot shows that, in general, the more expensive a drink is, the fewer of it are sold. However, the drink that costs $4.02 is an outlier, which may suggest that it's a particularly popular product. When you use scatter plots in this way, close inspection can help you explore the relationship between the variables. You can then carry out further analysis, for example with linear regression or another technique.
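The paragraph above mentions following up with linear regression. Here is a minimal sketch of that step, assuming NumPy is available (it is installed along with Matplotlib). It quantifies the trend with np.corrcoef() and fits a straight line with np.polyfit(). Both functions are a choice made here for illustration and are not part of the original walkthrough:

```python
import numpy as np

price = np.asarray([2.50, 1.23, 4.02, 3.25, 5.00, 4.40])
sales_per_day = np.asarray([34, 62, 49, 22, 13, 19])

# Pearson correlation coefficient: a negative value means sales tend to fall as price rises
r = np.corrcoef(price, sales_per_day)[0, 1]

# Least-squares straight line: sales_per_day is approximately slope * price + intercept
slope, intercept = np.polyfit(price, sales_per_day, deg=1)

print(f"correlation coefficient: {r:.2f}")
print(f"fitted line: sales = {slope:.1f} * price + {intercept:.1f}")
```

With only six data points, the $4.02 outlier has a noticeable influence on both numbers, so treat them as a rough summary rather than a firm conclusion.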
You can also produce the scatter plot shown above with another function in matplotlib.pyplot. Matplotlib's plt.plot() is a general-purpose plotting function that lets you create a wide variety of line and marker plots.
You can reproduce the scatter plot from the previous section with the following call to plt.plot(), using the same data:
```python
# ... (same imports and data as the previous script)
plt.plot(price, sales_per_day, "o")
plt.show()
```
In this case, you have to include the marker "o" as a third argument, because otherwise plt.plot() draws a line graph. The plot you created with this code looks the same as the one you created earlier with plt.scatter().
For a basic scatter plot like the one in this example, plt.plot() may sometimes be the better choice. You can compare the performance of the two functions using the timeit module:
```python
import timeit

import matplotlib.pyplot as plt

price = [2.50, 1.23, 4.02, 3.25, 5.00, 4.40]
sales_per_day = [34, 62, 49, 22, 13, 19]

print(
    "plt.scatter()",
    timeit.timeit(
        "plt.scatter(price, sales_per_day)",
        number=1000,
        globals=globals(),
    ),
)
print(
    "plt.plot()",
    timeit.timeit(
        "plt.plot(price, sales_per_day, 'o')",
        number=1000,
        globals=globals(),
    ),
)
```
When you run this code, you'll likely find that plt.plot() is significantly faster than plt.scatter(). The exact timings depend on your machine, but the difference is usually large. When I ran the example above on my computer, plt.plot() was almost seven times faster.
If plt.plot() can create scatter plots and is also much faster, why would you ever use plt.scatter()? You'll find the answer in the rest of this guide. Most of the customisations and advanced uses covered in this guide are only possible with plt.scatter(). Here's a rule of thumb you can use:
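One plausible reason for the speed gap, offered here as an assumption rather than a measured explanation, is the kind of object each function creates. plt.plot() returns a list of Line2D objects, and a single Line2D draws every marker with the same style. plt.scatter() returns a PathCollection, which can hold a different size and colour for each point. The sketch below only inspects the returned objects and does not time anything. It selects the non-interactive Agg backend so that it runs without a display:

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend, so no window is opened

import matplotlib.pyplot as plt
from matplotlib.collections import PathCollection
from matplotlib.lines import Line2D

price = [2.50, 1.23, 4.02, 3.25, 5.00, 4.40]
sales_per_day = [34, 62, 49, 22, 13, 19]

# plt.scatter() returns one collection holding all six points
scatter_artist = plt.scatter(price, sales_per_day)

# plt.plot() returns a list of lines; here it holds one Line2D drawn with "o" markers
line_artists = plt.plot(price, sales_per_day, "o")

print(type(scatter_artist).__name__)
print(len(line_artists), type(line_artists[0]).__name__)
```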
- If you need a basic scatter plot, use plt.plot(), especially if you want to prioritise performance.
- If you want to customise your scatter plot with more advanced plotting features, use plt.scatter().

You'll see some of the more advanced uses of plt.scatter() in the next section.
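Every example in this guide calls functions on plt directly. Matplotlib also offers an object-oriented interface, in which you create the Figure and Axes explicitly and call the Axes.scatter() method. That method takes the same arguments covered in this guide. Here is a minimal sketch of the first café plot in that style. The axis labels are added for illustration:

```python
import matplotlib.pyplot as plt

price = [2.50, 1.23, 4.02, 3.25, 5.00, 4.40]
sales_per_day = [34, 62, 49, 22, 13, 19]

# Create a figure with a single Axes explicitly
fig, ax = plt.subplots()

# Axes.scatter() accepts the same x, y, s, c, marker, cmap and alpha arguments
points = ax.scatter(price, sales_per_day)
ax.set_xlabel("Price (Currency Unit)")
ax.set_ylabel("Average daily sales")

plt.show()
```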
Customizing Markers in Scatter Plots
You can visualise more than two variables on a two-dimensional scatter plot by customising the markers. plt.scatter() lets you customise four main features of the markers: size, colour, shape, and transparency.
In the following sections, you'll learn how to change each of them.
Changing the Size
Let's return to the café owner you met earlier in this guide. The orange drinks he sells come from different suppliers and have different profit margins. You can show this additional information in the scatter plot by adjusting the size of the markers. In this example, the profit margin is given as a percentage:
```python
import matplotlib.pyplot as plt
import numpy as np

price = np.asarray([2.50, 1.23, 4.02, 3.25, 5.00, 4.40])
sales_per_day = np.asarray([34, 62, 49, 22, 13, 19])
profit_margin = np.asarray([20, 35, 40, 20, 27.5, 15])

plt.scatter(x=price, y=sales_per_day, s=profit_margin * 10)
plt.show()
```
This example differs from the first one in a few ways. You're now using NumPy arrays instead of lists. NumPy arrays are common in this kind of application because they support efficient element-wise operations, such as multiplying every profit margin by 10. You can use any array-like data structure for the data. NumPy is a dependency of Matplotlib, so it's installed automatically and you don't need to install it yourself.
You've also passed the input arguments as named parameters in the function call. The parameters x and y are required, and all the other parameters are optional.
The parameter s sets the size of the markers. In this example, you use the profit margin to determine the size of each marker and multiply it by 10 so that the differences in size are easier to see.
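The element-wise behaviour matters for the s=profit_margin * 10 argument. If profit_margin were a plain Python list, multiplying it by 10 would repeat the list instead of scaling each value. The small check below shows the difference:

```python
import numpy as np

profit_margin_list = [20, 35, 40, 20, 27.5, 15]
profit_margin_array = np.asarray(profit_margin_list)

# A list times an integer repeats the list: 6 items become 60 items
repeated = profit_margin_list * 10

# A NumPy array times a number multiplies every element
scaled = profit_margin_array * 10

print(len(repeated))  # 60
print(scaled)
```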
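Matplotlib documents s as the marker area in points squared, so a marker with s=400 has twice the diameter of one with s=100, not four times. You can pass s either as one value per point or as a single number for all points. The returned collection stores these values, and you can read them back with get_sizes(). This is a minimal sketch that uses the Agg backend so it runs without a display:

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend, so no window is opened

import matplotlib.pyplot as plt
import numpy as np

price = np.asarray([2.50, 1.23, 4.02, 3.25, 5.00, 4.40])
sales_per_day = np.asarray([34, 62, 49, 22, 13, 19])
profit_margin = np.asarray([20, 35, 40, 20, 27.5, 15])

# One size per point: s is an array with the same length as x and y
varying = plt.scatter(x=price, y=sales_per_day, s=profit_margin * 10)

# One size for every point: s is a single number
uniform = plt.scatter(x=price, y=sales_per_day, s=50)

print(varying.get_sizes())
print(uniform.get_sizes())
```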
You can see the scatter plot created by this code below:
The size of each marker represents the profit margin of that product. The two orange drinks that sell the most also have the highest profit margins. That's good news for the café owner.
Changing the Color
Many of the café's customers read the labels carefully, especially to check the sugar content of the drinks they buy. The café owner wants to emphasise his selection of healthy products in his next marketing campaign. He therefore groups the drinks by sugar content and uses a traffic light system to mark low, medium, or high sugar content.
You can colour the markers in the scatter plot to show the sugar content of each drink:
```python
# ...

low = (0, 1, 0)
medium = (1, 1, 0)
high = (1, 0, 0)

sugar_content = [low, high, medium, medium, high, low]

plt.scatter(
    x=price,
    y=sales_per_day,
    s=profit_margin * 10,
    c=sugar_content,
)
plt.show()
```
You define the variables low, medium, and high as tuples. Each tuple contains three values that represent the red, green, and blue components of a colour, in that order. These are RGB colour values. The tuples for low, medium, and high represent green, yellow, and red, respectively.
You then define the list sugar_content to classify each drink. You use the optional parameter c in the function call to set the colour of each marker. Here's the scatter plot produced by this code:
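RGB tuples in Matplotlib use values between 0 and 1. As a sketch of an alternative, the same traffic-light colours can be written as hex strings. matplotlib.colors.to_rgb() converts any colour specification Matplotlib accepts into such a tuple, so you can check what a colour really is. Note that the named colour "green" is the darker CSS green, while (0, 1, 0) corresponds to "lime":

```python
from matplotlib.colors import to_rgb

# Hex equivalents of the low, medium and high tuples used above
low = "#00ff00"
medium = "#ffff00"
high = "#ff0000"

print(to_rgb(low))     # (0.0, 1.0, 0.0)
print(to_rgb(medium))  # (1.0, 1.0, 0.0)
print(to_rgb(high))    # (1.0, 0.0, 0.0)

# Named colours: "lime" is pure green, while "green" is a darker shade
print(to_rgb("lime"))
print(to_rgb("green"))
```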
The café owner has already decided to take the most expensive drink off the menu, because it doesn't sell well and has a high sugar content. Should he also stop stocking the cheapest drink to improve the business's health credentials, even though it sells well and has a good profit margin?
Changing the Shape
The café owner found this exercise very useful, and he now wants to investigate another product. In addition to the orange drinks, you'll now plot similar data for the range of cereal bars sold in the café:
```python
import matplotlib.pyplot as plt
import numpy as np

low = (0, 1, 0)
medium = (1, 1, 0)
high = (1, 0, 0)

price_orange = np.asarray([2.50, 1.23, 4.02, 3.25, 5.00, 4.40])
sales_per_day_orange = np.asarray([34, 62, 49, 22, 13, 19])
profit_margin_orange = np.asarray([20, 35, 40, 20, 27.5, 15])
sugar_content_orange = [low, high, medium, medium, high, low]

price_cereal = np.asarray([1.50, 2.50, 1.15, 1.95])
sales_per_day_cereal = np.asarray([67, 34, 36, 12])
profit_margin_cereal = np.asarray([20, 42.5, 33.3, 18])
sugar_content_cereal = [low, high, medium, low]

plt.scatter(
    x=price_orange,
    y=sales_per_day_orange,
    s=profit_margin_orange * 10,
    c=sugar_content_orange,
)
plt.scatter(
    x=price_cereal,
    y=sales_per_day_cereal,
    s=profit_margin_cereal * 10,
    c=sugar_content_cereal,
)
plt.show()
```
In this code, you rename the variables to reflect that you now have data for two different products. You then plot both scatter plots in a single figure. This gives the following output:
Unfortunately, you can no longer tell which data points belong to the orange drinks and which belong to the cereal bars. You can change the marker shape for one of the scatter plots:
```python
import matplotlib.pyplot as plt
import numpy as np

low = (0, 1, 0)
medium = (1, 1, 0)
high = (1, 0, 0)

price_orange = np.asarray([2.50, 1.23, 4.02, 3.25, 5.00, 4.40])
sales_per_day_orange = np.asarray([34, 62, 49, 22, 13, 19])
profit_margin_orange = np.asarray([20, 35, 40, 20, 27.5, 15])
sugar_content_orange = [low, high, medium, medium, high, low]

price_cereal = np.asarray([1.50, 2.50, 1.15, 1.95])
sales_per_day_cereal = np.asarray([67, 34, 36, 12])
profit_margin_cereal = np.asarray([20, 42.5, 33.3, 18])
sugar_content_cereal = [low, high, medium, low]

plt.scatter(
    x=price_orange,
    y=sales_per_day_orange,
    s=profit_margin_orange * 10,
    c=sugar_content_orange,
)
plt.scatter(
    x=price_cereal,
    y=sales_per_day_cereal,
    s=profit_margin_cereal * 10,
    c=sugar_content_cereal,
    marker="d",
)
plt.show()
```
You keep the default marker shape for the orange drink data. The default marker is "o", which draws a circle. For the cereal bar data, you set the marker shape to "d", which draws a thin diamond. The documentation page on markers lists all the markers you can use. Here are the two scatter plots overlaid on the same figure:
You can now tell the orange drink data points apart from the cereal bar data points. However, the last plot you created has one flaw, which you'll explore in the next section.
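Marker codes are single characters, and similar-looking ones are easy to confuse. In Matplotlib, "d" is a thin diamond and "D" is a full-width diamond. The MarkerStyle.markers dictionary maps each code to a descriptive name, so a quick lookup like the one below lets you check a code before using it:

```python
from matplotlib.markers import MarkerStyle

# MarkerStyle.markers maps marker codes to descriptive names
for code in ["o", "d", "D", "s", "^"]:
    print(repr(code), "->", MarkerStyle.markers[code])
```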
Changing the Transparency
One of the orange drink data points seems to have disappeared. There should be six orange drinks, but only five round markers are visible in the figure. One of the orange drink data points is hidden behind a cereal bar data point.
You can fix this visualisation problem by making the data points partially transparent with the alpha parameter:
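With only ten points, you can spot the overlap by eye, but larger data sets need a programmatic check. The following sketch is not part of the original workflow. It uses NumPy broadcasting to compare every orange drink point with every cereal bar point and reports the pairs that share the same price and daily sales:

```python
import numpy as np

price_orange = np.asarray([2.50, 1.23, 4.02, 3.25, 5.00, 4.40])
sales_per_day_orange = np.asarray([34, 62, 49, 22, 13, 19])
price_cereal = np.asarray([1.50, 2.50, 1.15, 1.95])
sales_per_day_cereal = np.asarray([67, 34, 36, 12])

# Stack each product's data into an (n, 2) array of (price, sales) pairs
orange_points = np.column_stack((price_orange, sales_per_day_orange))
cereal_points = np.column_stack((price_cereal, sales_per_day_cereal))

# Broadcast (6, 1, 2) against (1, 4, 2) to compare every pair of points,
# then require both coordinates to match
same_point = np.all(
    np.isclose(orange_points[:, np.newaxis, :], cereal_points[np.newaxis, :, :]),
    axis=-1,
)
orange_idx, cereal_idx = np.nonzero(same_point)

for i, j in zip(orange_idx, cereal_idx):
    print(f"orange drink {i} overlaps cereal bar {j} at {orange_points[i]}")
```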
```python
# ...

plt.scatter(
    x=price_orange,
    y=sales_per_day_orange,
    s=profit_margin_orange * 10,
    c=sugar_content_orange,
    alpha=0.5,
)
plt.scatter(
    x=price_cereal,
    y=sales_per_day_cereal,
    s=profit_margin_cereal * 10,
    c=sugar_content_cereal,
    marker="d",
    alpha=0.5,
)

plt.title("Sales vs Prices for Orange Drinks and Cereal Bars")
plt.legend(["Orange Drinks", "Cereal Bars"])
plt.xlabel("Price (Currency Unit)")
plt.ylabel("Average daily sales")  # the data are sales per day
plt.text(
    3.2,
    55,
    "Size of marker = profit margin\n"
    "Color of marker = sugar content",
)

plt.show()
```
You've set the alpha value of both sets of markers to 0.5, which makes them semitransparent. You can now see all the data points in this plot, including those that coincide:
You've also added a title and other labels to complete the figure and explain what it shows.
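A rough mental model helps explain why alpha=0.5 reveals the hidden point. A marker with transparency alpha drawn over an opaque background gives each colour channel the value alpha × marker + (1 − alpha) × background, which is the standard "over" compositing rule. The helper below is a hypothetical illustration of that rule, not a Matplotlib function:

```python
def blend_over(marker_rgb, background_rgb, alpha):
    """Hypothetical helper: composite an RGB colour with transparency alpha
    over an opaque background colour, channel by channel."""
    return tuple(
        alpha * m + (1 - alpha) * b for m, b in zip(marker_rgb, background_rgb)
    )


white = (1.0, 1.0, 1.0)
green = (0.0, 1.0, 0.0)
red = (1.0, 0.0, 0.0)

# A green marker with alpha=0.5 on a white background looks pale green
pale_green = blend_over(green, white, 0.5)
print(pale_green)  # (0.5, 1.0, 0.5)

# Where a red marker overlaps it, red is blended over the pale green,
# so both markers still influence the final colour
print(blend_over(red, pale_green, 0.5))
```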
Customizing the Colormap and Style
In the scatter plots you've created so far, each drink and cereal bar has one of three colours, showing low, medium, or high sugar content. You'll now change this so that the colour represents the actual sugar content of each item.
You first need to refactor the variables sugar_content_orange and sugar_content_cereal so that they hold the sugar content of each item instead of an RGB colour:
```python
sugar_content_orange = [15, 35, 22, 27, 38, 14]
sugar_content_cereal = [21, 49, 29, 24]
```
These lists now contain, for each item, the percentage of the recommended daily amount of sugar it contains. The rest of the code stays the same, but you can now choose which colormap to use. The colormap maps these values to colours:
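```python
# ...

# Use one colour scale for both products so that equal sugar content
# gets the same colour and the colorbar applies to both scatter plots
vmin = min(min(sugar_content_orange), min(sugar_content_cereal))
vmax = max(max(sugar_content_orange), max(sugar_content_cereal))

plt.scatter(
    x=price_orange,
    y=sales_per_day_orange,
    s=profit_margin_orange * 10,
    c=sugar_content_orange,
    cmap="jet",
    vmin=vmin,
    vmax=vmax,
    alpha=0.5,
)
plt.scatter(
    x=price_cereal,
    y=sales_per_day_cereal,
    s=profit_margin_cereal * 10,
    c=sugar_content_cereal,
    cmap="jet",
    vmin=vmin,
    vmax=vmax,
    marker="d",
    alpha=0.5,
)

plt.title("Sales vs Prices for Orange Drinks and Cereal Bars")
plt.legend(["Orange Drinks", "Cereal Bars"])
plt.xlabel("Price (Currency Unit)")
plt.ylabel("Average daily sales")  # the data are sales per day
plt.text(
    2.7,
    55,
    "Size of marker = profit margin\n"
    "Color of marker = sugar content",
)
plt.colorbar()

plt.show()
```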
The colour of the markers now follows a continuous scale, and you've also displayed the colorbar, which acts as a legend for the marker colours. Here's the resulting scatter plot:
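Behind the scenes, each numeric c value is first scaled to the range 0 to 1 and then looked up in the colormap. By default, the scaling uses the smallest and largest values passed in that particular call. You can reproduce the two steps with matplotlib.colors.Normalize and the matplotlib.colormaps registry, which is available in Matplotlib 3.5 and later. This sketch also shows why a common vmin and vmax matters when two scatter plots should share one colour scale. Without them, each call would normalise its own data separately:

```python
import matplotlib
from matplotlib.colors import Normalize

sugar_content_orange = [15, 35, 22, 27, 38, 14]
sugar_content_cereal = [21, 49, 29, 24]

# One scale covering both products, so equal sugar values get equal colours
all_values = sugar_content_orange + sugar_content_cereal
norm = Normalize(vmin=min(all_values), vmax=max(all_values))
cmap = matplotlib.colormaps["jet"]

for value in (14, 35, 49):
    scaled = float(norm(value))  # 0.0 for the minimum, 1.0 for the maximum
    rgba = cmap(scaled)  # (red, green, blue, alpha) tuple
    print(value, round(scaled, 3), rgba)
```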
All the plots you've created so far use the default Matplotlib style, but you can choose from several other styles. You can list the available styles with the following command (the exact list depends on your Matplotlib version):
```pycon
>>> plt.style.available
[
    "Solarize_Light2",
    "_classic_test_patch",
    "bmh",
    "classic",
    "dark_background",
    "fast",
    "fivethirtyeight",
    "ggplot",
    "grayscale",
    "seaborn",
    "seaborn-bright",
    "seaborn-colorblind",
    "seaborn-dark",
    "seaborn-dark-palette",
    "seaborn-darkgrid",
    "seaborn-deep",
    "seaborn-muted",
    "seaborn-notebook",
    "seaborn-paper",
    "seaborn-pastel",
    "seaborn-poster",
    "seaborn-talk",
    "seaborn-ticks",
    "seaborn-white",
    "seaborn-whitegrid",
    "tableau-colorblind10",
]
```
You can change the plot style by calling the following function before calling plt.scatter():
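```python
import matplotlib.pyplot as plt
import numpy as np

# Matplotlib 3.6 and later name this style "seaborn-v0_8";
# on older versions, use plt.style.use("seaborn") instead
plt.style.use("seaborn-v0_8")

# ...
```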
This changes the style to resemble Seaborn, another third-party visualisation package. You can see the difference by redrawing the final scatter plot from above in the Seaborn style:
You can read more about customising plots in Matplotlib, and the Matplotlib documentation pages contain further tutorials.
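plt.style.use() changes the style for the rest of the script. If you want a style for a single figure only, plt.style.context() applies it temporarily inside a with block. The sketch below also checks plt.style.available first, because Matplotlib 3.6 renamed the Seaborn styles to names starting with "seaborn-v0_8":

```python
import matplotlib.pyplot as plt

price = [2.50, 1.23, 4.02, 3.25, 5.00, 4.40]
sales_per_day = [34, 62, 49, 22, 13, 19]

# Pick whichever Seaborn style name this Matplotlib version provides
style_name = "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"

# The style applies only to figures created inside the with block
with plt.style.context(style_name):
    fig, ax = plt.subplots()
    ax.scatter(price, sales_per_day)

plt.show()
```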
Creating scatter plots with plt.scatter() lets you display more than two variables. Here are the variables represented in the final plot:
| Variable | Represented by |
|---|---|
| Price | X-axis |
| Average number sold | Y-axis |
| Profit margin | Marker size |
| Product type | Marker shape |
| Sugar content | Marker color |
Because it can represent more than two variables, plt.scatter() is a very powerful and versatile function.
plt.scatter() offers even more flexibility for customising scatter plots. In this section, you'll work through an example of masking data with NumPy arrays and scatter plots. You'll generate random data points and then separate them into two distinct regions within the same scatter plot.
A commuter who likes collecting data has recorded the arrival times of buses at her local bus stop over six months. The timetabled arrival times are at 15 minutes and 45 minutes past the hour, but she noticed that the actual arrival times follow a normal distribution around these times.
This plot shows the relative likelihood of a bus arriving at each minute within an hour. You can represent this probability distribution using NumPy and np.linspace():
```python
import matplotlib.pyplot as plt
import numpy as np

mean = 15, 45
sd = 5, 7

x = np.linspace(0, 59, 60)  # Represents each minute within the hour
# Use the first mean and standard deviation for the first peak, the second for the second
first_distribution = np.exp(-0.5 * ((x - mean[0]) / sd[0]) ** 2)
second_distribution = 0.9 * np.exp(-0.5 * ((x - mean[1]) / sd[1]) ** 2)
y = first_distribution + second_distribution
y = y / max(y)  # Scale so that the most likely minute has a value of 1

plt.plot(x, y)
plt.ylabel("Relative probability of bus arrivals")
plt.xlabel("Minutes past the hour")
plt.show()
```
You've added together two normal distributions centred on 15 and 45 minutes past the hour. You then divide the result by its maximum value, so the most likely arrival time has a relative probability of one.
You can now simulate bus arrival times using this distribution. To do this, you can generate random times and random relative probabilities with the built-in random module. The code below also uses list comprehensions:
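Written out, the curve computed in the code above is the sum of two Gaussian-shaped terms, divided by its largest value over the hour, where t is the number of minutes past the hour:

$$
f(t) = \exp\left(-\frac{1}{2}\left(\frac{t - 15}{5}\right)^{2}\right) + 0.9\,\exp\left(-\frac{1}{2}\left(\frac{t - 45}{7}\right)^{2}\right),
\qquad
y(t) = \frac{f(t)}{\max_{0 \le t' \le 59} f(t')}
$$

The factor 0.9 makes the peak at 45 minutes slightly lower than the one at 15 minutes, and the division guarantees that the highest value of y is exactly 1.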
```python
import random

import matplotlib.pyplot as plt
import numpy as np

n_buses = 40
bus_times = np.asarray([random.randint(0, 59) for _ in range(n_buses)])
bus_likelihood = np.asarray([random.random() for _ in range(n_buses)])

plt.scatter(x=bus_times, y=bus_likelihood)
plt.title("Randomly chosen bus arrival times and relative probabilities")
plt.ylabel("Relative probability of bus arrivals")
plt.xlabel("Minutes past the hour")
plt.show()
```
You've simulated 40 bus arrivals, which you can visualise in the following scatter plot:
Because the data are random, your plot will look different. However, not all of these points are likely to be close to what the commuter actually observed in the data she collected and analysed. You can plot the distribution she obtained from her data together with the simulated bus arrivals:
```python
import random

import matplotlib.pyplot as plt
import numpy as np

mean = 15, 45
sd = 5, 7

x = np.linspace(0, 59, 60)
first_distribution = np.exp(-0.5 * ((x - mean[0]) / sd[0]) ** 2)
second_distribution = 0.9 * np.exp(-0.5 * ((x - mean[1]) / sd[1]) ** 2)
y = first_distribution + second_distribution
y = y / max(y)

n_buses = 40
bus_times = np.asarray([random.randint(0, 59) for _ in range(n_buses)])
bus_likelihood = np.asarray([random.random() for _ in range(n_buses)])

plt.scatter(x=bus_times, y=bus_likelihood)
plt.plot(x, y)
plt.title("Randomly chosen bus arrival times and relative probabilities")
plt.ylabel("Relative probability of bus arrivals")
plt.xlabel("Minutes past the hour")
plt.show()
```
This gives the following output:
To keep the simulation realistic, you need to make sure that the random bus arrivals are consistent with the data and with the distribution derived from it. One way to filter the data is to keep only the randomly generated points that fall within the probability distribution. You can do this by creating a mask for the scatter plot:
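Because random.randint() and random.random() are not seeded here, each run produces different points. You might want repeatable results, for instance while checking the mask in the next step. One alternative is NumPy's random Generator with a fixed seed. The seed value 2024 is arbitrary. Note that random.randint(0, 59) includes 59, whereas Generator.integers() excludes its upper bound by default, so the sketch uses 60:

```python
import numpy as np

n_buses = 40

# A seeded generator produces the same values on every run
rng = np.random.default_rng(seed=2024)

# integers() excludes the upper bound, so minutes 0 to 59 need high=60
bus_times = rng.integers(0, 60, size=n_buses)

# random() draws floats in the half-open interval [0.0, 1.0)
bus_likelihood = rng.random(n_buses)
```

Both results are NumPy arrays, so they can stand in for bus_times and bus_likelihood in the scripts in this section.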
```python
# ...

in_region = bus_likelihood < y[bus_times]
out_region = bus_likelihood >= y[bus_times]

plt.scatter(
    x=bus_times[in_region],
    y=bus_likelihood[in_region],
    color="green",
)
plt.scatter(
    x=bus_times[out_region],
    y=bus_likelihood[out_region],
    color="red",
    marker="x",
)

plt.plot(x, y)
plt.title("Randomly chosen bus arrival times and relative probabilities")
plt.ylabel("Relative probability of bus arrivals")
plt.xlabel("Minutes past the hour")
plt.show()
```
The variables in_region and out_region are NumPy arrays of Boolean values. Each value records whether a randomly generated likelihood falls below or above the distribution y at that bus's arrival minute. Indexing with y[bus_times] works because x holds the whole minutes 0 to 59, so every arrival minute is also a valid index into y. You then create two separate scatter plots: one for the points inside the distribution and one for the points outside it. The data points above the distribution aren't representative of the real data:
You've split the data points from the original scatter plot into two groups, those inside the distribution and those outside it, and given each group its own colour and marker.
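The two key steps in the masking code are looking up y[bus_times] and comparing the result element-wise. A small example with made-up numbers shows both. All the values here are hypothetical stand-ins for y, bus_times, and bus_likelihood:

```python
import numpy as np

# Hypothetical probability curve sampled at minutes 0, 1, 2 and 3
profile = np.array([0.1, 0.4, 1.0, 0.4])

# Hypothetical simulated arrivals: the arrival minute and a random likelihood
minutes = np.array([2, 0, 3, 2])
likelihood = np.array([0.5, 0.3, 0.2, 1.0])

# Integer-array indexing looks up the curve value at each arrival minute
threshold = profile[minutes]  # [1.0, 0.1, 0.4, 1.0]

# Element-wise comparison gives a Boolean mask
in_region = likelihood < threshold  # [True, False, True, False]
out_region = ~in_region  # equivalent to likelihood >= threshold here

print(minutes[in_region])  # [2 3]
print(minutes[out_region])  # [0 2]
```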
Reviewing the Key Input Parameters
In the previous sections, you learned about the main input parameters for creating scatter plots. Here's a brief summary of the key points to remember about them:
| Parameter | Description |
|---|---|
| x and y | These parameters represent the two main variables and can be any array-like data types, such as lists or NumPy arrays. These are required parameters. |
| s | This parameter defines the size of the marker. It can be a single number (float) if all the markers have the same size, or an array-like data structure if the markers have different sizes. |
| c | This parameter represents the color of the markers. It will typically be either an array of colors, such as RGB values, or a sequence of values that will be mapped onto a colormap using the parameter cmap. |
| marker | This parameter is used to customize the shape of the marker. |
| cmap | If a sequence of values is used for the parameter c, then this parameter can be used to select the mapping between values and colors, typically by using one of the standard colormaps or a custom colormap. |
| alpha | This parameter is a float that can take any value between 0 and 1 and represents the transparency of the markers, where 1 represents an opaque marker. |
plt.scatter() accepts many more input parameters than the ones listed above. You can find the full list in the documentation.