Pandas: Create Scatter plot from multiple DataFrame columns

Borislav Hadzhiev

Last updated: Apr 12, 2024

# Table of Contents

  1. Pandas: Create Scatter plot from multiple DataFrame columns
  2. Setting the names of the labels
  3. Pandas: Create Scatter plot from multiple DataFrame columns using a for loop

# Pandas: Create Scatter plot from multiple DataFrame columns

Set the ax argument when calling DataFrame.plot() to create a scatter plot from multiple DataFrame columns in Pandas.

The ax argument specifies the Matplotlib axes that the plot is drawn on.

main.py
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(
    [
        [1.038988, 1.15915721, -2.28067047, 2.15586249],
        [1.67701311, -0.29835294, 1.85995237, 0.09892021],
        [-1.27494246, -1.10266464, -0.53293884, 0.65932129],
        [1.57701142, -0.75032143, 1.06973893, 0.82543657],
        [-0.99777099, -0.46051326, 1.43704249, 1.24864683],
        [-0.73344725, 1.63418558, 0.7973022, -1.9192436]
    ],
    columns=['A', 'B', 'C', 'D']
)

ax1 = df.plot(kind='scatter', x='A', y='B', color='g')
ax2 = df.plot(kind='scatter', x='C', y='D', color='r', ax=ax1)

plt.show()

Figure: scatter plot created from multiple DataFrame columns.

The code for this article is available on GitHub

The DataFrame.plot() method makes plots of a DataFrame.

main.py
ax1 = df.plot(kind='scatter', x='A', y='B', color='g')
ax2 = df.plot(kind='scatter', x='C', y='D', color='r', ax=ax1)

We passed the following arguments to the method:

  • kind - the kind of plot to produce.
  • x - the column label (or position) to use for the horizontal coordinates.
  • y - the column label (or position) to use for the vertical coordinates.
  • color - the color of the markers.
  • ax - the Matplotlib axes to draw the plot on.

The ax argument determines the axes the plot is drawn into.

If you don't supply the ax argument, the DataFrame.plot() method creates a new figure and axes.

Notice that we didn't pass an ax argument in the first call to plot().

main.py
ax1 = df.plot(kind='scatter', x='A', y='B', color='g')
ax2 = df.plot(kind='scatter', x='C', y='D', color='r', ax=ax1)

Instead, we used the axes returned by the first call as the ax argument in the second call.

If you have multiple calls to the DataFrame.plot() method, you would still reuse the same value for the ax argument.

main.py
ax1 = df.plot(kind='scatter', x='A', y='B', color='g')
ax2 = df.plot(kind='scatter', x='C', y='D', color='r', ax=ax1)
ax3 = df.plot(kind='scatter', x='E', y='F', color='b', ax=ax1)  # assumes df also has columns 'E' and 'F'

To create a scatter plot using multiple DataFrame columns, the ax argument in all subsequent DataFrame.plot() calls has to be the same axes (ax1).
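The three-call snippet above refers to columns E and F, which the example DataFrame does not have. The following is a minimal sketch that assumes a six-column DataFrame filled with made-up random values. The seed and the Agg backend line are illustration choices, not part of the original example. It checks that all three calls drew on the same axes.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Made-up data: six rows, six columns named A to F.
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.normal(size=(6, 6)), columns=list("ABCDEF"))

# The first call creates the axes; the next two reuse it through ax=ax1.
ax1 = df.plot(kind="scatter", x="A", y="B", color="g")
ax2 = df.plot(kind="scatter", x="C", y="D", color="r", ax=ax1)
ax3 = df.plot(kind="scatter", x="E", y="F", color="b", ax=ax1)

same_axes = ax1 is ax2 is ax3        # every call returned the same axes
series_drawn = len(ax1.collections)  # one collection per scatter call

print(same_axes)     # True
print(series_drawn)  # 3
```

In an interactive session you would drop the matplotlib.use("Agg") line and call plt.show() to display the figure.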

You can verify that this is the case by comparing the axes that are returned from DataFrame.plot().

main.py
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(
    [
        [1.038988, 1.15915721, -2.28067047, 2.15586249],
        [1.67701311, -0.29835294, 1.85995237, 0.09892021],
        [-1.27494246, -1.10266464, -0.53293884, 0.65932129],
        [1.57701142, -0.75032143, 1.06973893, 0.82543657],
        [-0.99777099, -0.46051326, 1.43704249, 1.24864683],
        [-0.73344725, 1.63418558, 0.7973022, -1.9192436]
    ],
    columns=['A', 'B', 'C', 'D']
)

ax1 = df.plot(kind='scatter', x='A', y='B', color='g')
ax2 = df.plot(kind='scatter', x='C', y='D', color='r', ax=ax1)

print(ax1 == ax2)  # 👉️ True

plt.show()

Figure: comparing the axes returned from DataFrame.plot().

The equality comparison will return True if the ax argument is supplied correctly.
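For contrast, here is a small sketch of what happens when the ax argument is left out of the second call. The DataFrame values are made up for the illustration, and the Agg backend is selected only so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Small made-up DataFrame, used only for this illustration.
df = pd.DataFrame({
    "A": [1.0, 2.0, 3.0],
    "B": [3.0, 1.0, 2.0],
    "C": [0.5, 1.5, 2.5],
    "D": [2.0, 2.5, 0.5],
})

plt.close("all")  # start with no open figures

first = df.plot(kind="scatter", x="A", y="B", color="g")
# No ax argument here, so the second call gets its own figure and axes.
second = df.plot(kind="scatter", x="C", y="D", color="r")

separate_axes = first is not second
figure_count = len(plt.get_fignums())

print(separate_axes)  # True
print(figure_count)   # 2
```

Two open figures mean the two pairs of columns ended up in separate plots instead of being combined.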

# Setting the names of the labels

You can also give each set of points a legend label by passing the label argument when calling DataFrame.plot().

main.py
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(
    [
        [1.038988, 1.15915721, -2.28067047, 2.15586249],
        [1.67701311, -0.29835294, 1.85995237, 0.09892021],
        [-1.27494246, -1.10266464, -0.53293884, 0.65932129],
        [1.57701142, -0.75032143, 1.06973893, 0.82543657],
        [-0.99777099, -0.46051326, 1.43704249, 1.24864683],
        [-0.73344725, 1.63418558, 0.7973022, -1.9192436]
    ],
    columns=['A', 'B', 'C', 'D']
)

ax1 = df.plot(
    kind='scatter',
    x='A',
    y='B',
    color='g',
    label='First'
)
ax2 = df.plot(
    kind='scatter',
    x='C',
    y='D',
    color='r',
    ax=ax1,
    label='Second'
)

plt.show()

Figure: setting the names of the labels when calling DataFrame.plot().
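One way to check which legend entries ended up on the shared axes is Matplotlib's Axes.get_legend_handles_labels() method. A minimal sketch, assuming a small made-up DataFrame and the Agg backend so that it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Small made-up DataFrame, used only for this illustration.
df = pd.DataFrame({
    "A": [1.0, 2.0, 3.0],
    "B": [3.0, 1.0, 2.0],
    "C": [0.5, 1.5, 2.5],
    "D": [2.0, 2.5, 0.5],
})

ax1 = df.plot(kind="scatter", x="A", y="B", color="g", label="First")
df.plot(kind="scatter", x="C", y="D", color="r", ax=ax1, label="Second")

# Collect the labelled artists that were drawn on ax1.
handles, labels = ax1.get_legend_handles_labels()

print(labels)  # ['First', 'Second']
```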

If you need to add axis labels to the plot, call the set_xlabel() and set_ylabel() methods on ax1.

main.py
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(
    [
        [1.038988, 1.15915721, -2.28067047, 2.15586249],
        [1.67701311, -0.29835294, 1.85995237, 0.09892021],
        [-1.27494246, -1.10266464, -0.53293884, 0.65932129],
        [1.57701142, -0.75032143, 1.06973893, 0.82543657],
        [-0.99777099, -0.46051326, 1.43704249, 1.24864683],
        [-0.73344725, 1.63418558, 0.7973022, -1.9192436]
    ],
    columns=['A', 'B', 'C', 'D']
)

ax1 = df.plot(
    kind='scatter',
    x='A',
    y='B',
    color='g',
    label='First'
)
ax2 = df.plot(
    kind='scatter',
    x='C',
    y='D',
    color='r',
    ax=ax1,
    label='Second'
)

ax1.set_xlabel('horizontal label')
ax1.set_ylabel('vertical label')

plt.show()

Figure: calling the set_xlabel() and set_ylabel() methods on ax1.

I've written a detailed guide on how to add axis labels to a plot in Pandas.
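As an alternative sketch, Matplotlib's Axes.set() method accepts several properties at once, so the axis labels and a title can be set in one call. The DataFrame values and the title text below are made up for the illustration. The labels are set after all plot() calls because each scatter call may overwrite the axis labels with the column names it plotted.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Small made-up DataFrame, used only for this illustration.
df = pd.DataFrame({
    "A": [1.0, 2.0, 3.0],
    "B": [3.0, 1.0, 2.0],
    "C": [0.5, 1.5, 2.5],
    "D": [2.0, 2.5, 0.5],
})

ax = df.plot(kind="scatter", x="A", y="B", color="g", label="First")
df.plot(kind="scatter", x="C", y="D", color="r", ax=ax, label="Second")

# Set both axis labels and a title in a single call, after plotting.
ax.set(
    xlabel="horizontal label",
    ylabel="vertical label",
    title="A vs B and C vs D",  # hypothetical title text
)

print(ax.get_xlabel())
print(ax.get_ylabel())
```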

# Pandas: Create Scatter plot from multiple DataFrame columns using a for loop

You can also use a for loop if you want to make your code more concise and reusable.

main.py
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(
    [
        [1.038988, 1.15915721, -2.28067047, 2.15586249],
        [1.67701311, -0.29835294, 1.85995237, 0.09892021],
        [-1.27494246, -1.10266464, -0.53293884, 0.65932129],
        [1.57701142, -0.75032143, 1.06973893, 0.82543657],
        [-0.99777099, -0.46051326, 1.43704249, 1.24864683],
        [-0.73344725, 1.63418558, 0.7973022, -1.9192436]
    ],
    columns=['A', 'B', 'C', 'D']
)

f, ax = plt.subplots(1)

for x, y, color in zip(['A', 'C'], ['B', 'D'], ['g', 'r']):
    df.plot(
        kind='scatter',
        x=x,
        y=y,
        color=color,
        ax=ax,
        label=f'{x} vs {y}'
    )

ax.set_xlabel('horizontal label')
ax.set_ylabel('vertical label')

plt.show()

Figure: scatter plot created from multiple DataFrame columns using a for loop.

The zip() function iterates over several iterables in parallel and produces tuples with an item from each iterable.

main.py
print(list(zip(['A', 'C'], ['B', 'D'], ['g', 'r'])))  # [('A', 'B', 'g'), ('C', 'D', 'r')]
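One detail worth keeping in mind is that zip() stops at the shortest iterable, so a missing entry in one list silently drops a pair of columns from the plot. The sketch below uses made-up column names and a simple length check that is not part of the original example. On Python 3.10 and later, zip(..., strict=True) raises a ValueError instead.

```python
x_cols = ["A", "C", "E"]
y_cols = ["B", "D"]        # one entry short on purpose
colors = ["g", "r", "b"]

# zip() stops as soon as the shortest list runs out, so 'E' never appears.
pairs = list(zip(x_cols, y_cols, colors))
print(pairs)  # [('A', 'B', 'g'), ('C', 'D', 'r')]

# A simple guard: all three lists should have the same length.
lengths_match = len(x_cols) == len(y_cols) == len(colors)
print(lengths_match)  # False
```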

On each iteration of the for loop, we call the DataFrame.plot() method.

The matplotlib.pyplot.subplots() function creates a figure and a set of subplots. Calling it with 1 creates a figure with a single axes.

main.py
f, ax = plt.subplots(1)
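To illustrate what plt.subplots(1) returns, here is a short sketch. The variable names are arbitrary and the Agg backend is selected only so the code runs without a display. With one row and the default one column, the second return value is a single axes object. With more rows it is a NumPy array of axes, which could not be passed directly as the ax argument for a single scatter plot.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

# The first positional argument is nrows; ncols defaults to 1.
fig, ax = plt.subplots(1)
single_axes = isinstance(ax, Axes)
belongs_to_fig = ax.figure is fig

# With two rows, the second return value is an array holding two axes.
fig2, axes = plt.subplots(2)
axes_shape = axes.shape

print(single_axes)     # True
print(belongs_to_fig)  # True
print(axes_shape)      # (2,)
```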

Notice that the ax argument gets set to the same axes on each iteration.

main.py
for x, y, color in zip(['A', 'C'], ['B', 'D'], ['g', 'r']):
    df.plot(
        kind='scatter',
        x=x,
        y=y,
        color=color,
        ax=ax,
        label=f'{x} vs {y}'
    )

The three lists we passed to the zip function are:

  1. The values for the x argument in the DataFrame.plot() calls.
  2. The values for the y argument.
  3. The values for the color argument.
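If the loop is needed in more than one place, it could be wrapped in a small function. The sketch below is one possible shape for that, not part of the original example. The helper name scatter_pairs and its parameters are hypothetical, the DataFrame values are made up, and the Agg backend is selected only so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd


def scatter_pairs(df, pairs, ax=None):
    """Hypothetical helper: draw one scatter series per (x, y, color) tuple."""
    if ax is None:
        # Create a single shared axes when the caller does not supply one.
        _, ax = plt.subplots()
    for x, y, color in pairs:
        df.plot(kind="scatter", x=x, y=y, color=color, ax=ax, label=f"{x} vs {y}")
    return ax


# Small made-up DataFrame, used only for this illustration.
df = pd.DataFrame({
    "A": [1.0, 2.0, 3.0],
    "B": [3.0, 1.0, 2.0],
    "C": [0.5, 1.5, 2.5],
    "D": [2.0, 2.5, 0.5],
})

ax = scatter_pairs(df, [("A", "B", "g"), ("C", "D", "r")])
_, labels = ax.get_legend_handles_labels()

print(labels)  # ['A vs B', 'C vs D']
```

Passing the pairs as a list of tuples keeps each x column next to its y column and color, which avoids the mismatched-length problem that separate lists can run into.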
