How to flatten only some Dimensions of a NumPy array

Borislav Hadzhiev

Last updated: Apr 12, 2024

Reading time · 3 min

- How to flatten only some Dimensions of a NumPy array
- Using -1 as a shape dimension when flattening the array
- Flatten only some Dimensions of a NumPy array using numpy.vstack()

**Use the numpy.reshape() method to flatten only some dimensions of a NumPy
array.**

**The method will flatten the array, giving it a new shape, without changing its
data.**

main.py

```python
import numpy as np

arr = np.zeros((2, 4, 2))
print(arr)

print('-' * 50)

new_arr = arr.reshape(8, 2)
print(new_arr)
```

The code for this article is available on GitHub

Running the code sample produces the following output.

shell

```shell
[[[0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]]

 [[0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]]]
--------------------------------------------------
[[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
```

The numpy.reshape() method gives a new shape to an array without changing its data.
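Because `reshape()` does not change the data, the elements keep their original row-major order and only their grouping into rows changes. The sketch below uses `np.arange()` instead of `np.zeros()` (an assumption made purely so the order is visible) and uses `np.shares_memory()` to check that, for this contiguous array, the result is a view of the original.

```python
import numpy as np

# Distinct values make the element order visible
arr = np.arange(16).reshape(2, 4, 2)

new_arr = arr.reshape(8, 2)

# The rows of new_arr are the 2-element rows of arr, read in order
print(new_arr[:3])

# For a contiguous array, reshape() returns a view that shares memory
print(np.shares_memory(arr, new_arr))
```

Since the result is a view here, writing to `new_arr` would also change `arr`.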

The only arguments we passed to the `reshape()` method are 2 integers that represent the new shape.
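The new shape can also be passed as a single tuple, and the same operation is available as the function `np.reshape()`. A minimal sketch showing that the three forms produce the same shape:

```python
import numpy as np

arr = np.zeros((2, 4, 2))

a = arr.reshape(8, 2)        # separate integers
b = arr.reshape((8, 2))      # a single tuple
c = np.reshape(arr, (8, 2))  # function form

print(a.shape, b.shape, c.shape)  # (8, 2) (8, 2) (8, 2)
```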

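The new shape must be compatible with the original shape, which means that the product of the new dimensions must equal the number of elements in the array (`2 * 4 * 2 = 16 = 8 * 2`).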
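If the shapes are not compatible, NumPy raises a `ValueError`. A minimal sketch of that failure case, using `(5, 3)` (15 elements) as an arbitrary incompatible shape:

```python
import numpy as np

arr = np.zeros((2, 4, 2))
print(arr.size)  # 16

try:
    # 5 * 3 = 15 elements, which does not match the 16 in arr
    arr.reshape(5, 3)
    reshaped = True
except ValueError as e:
    reshaped = False
    print(f"ValueError: {e}")
```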

main.py

```python
import numpy as np

arr = np.zeros((2, 4, 2))

new_arr = arr.reshape(8, 2)

# [[0. 0.]
#  [0. 0.]
#  [0. 0.]
#  [0. 0.]
#  [0. 0.]
#  [0. 0.]
#  [0. 0.]
#  [0. 0.]]
print(new_arr)
```


Using `-1` as a shape dimension when flattening the array

One shape dimension can be `-1`.

When a shape dimension is `-1`, its value is inferred from the length of the array and the remaining dimensions.
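The `-1` can appear in any position, but only once: with two unknown dimensions the result would be ambiguous, so NumPy raises a `ValueError`. A minimal sketch:

```python
import numpy as np

arr = np.zeros((2, 4, 2))

# -1 in the last position is inferred as 16 // 4 = 4
print(arr.reshape(4, -1).shape)  # (4, 4)

try:
    # Two unknown dimensions cannot be inferred
    arr.reshape(-1, -1)
    two_unknowns_ok = True
except ValueError as e:
    two_unknowns_ok = False
    print(f"ValueError: {e}")
```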

main.py

```python
import numpy as np

arr = np.zeros((2, 4, 2))

new_arr = arr.reshape(-1, arr.shape[-1])

# [[0. 0.]
#  [0. 0.]
#  [0. 0.]
#  [0. 0.]
#  [0. 0.]
#  [0. 0.]
#  [0. 0.]
#  [0. 0.]]
print(new_arr)
```


The last shape dimension in the example is `2`, so the first shape dimension (`-1`) is inferred to be `8`.

main.py

```python
import numpy as np

arr = np.zeros((2, 4, 2))

print(arr.shape[-1])  # 2
```

This is calculated by dividing the total size of the array (`16`) by the product of all other specified dimensions (`2`), which gives `16 / 2 = 8`.

main.py

```python
import numpy as np

arr = np.zeros((2, 4, 2))

print(arr.size)  # 16
print(arr.shape[-1])  # 2
print(arr.size // arr.shape[-1])  # 8
```
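When more than one dimension is specified, the `-1` becomes the total size divided by the product of all of them. The sketch below checks this with `math.prod()` (Python 3.8+); the 4D shape and the target shape `(-1, 2, 4)` are arbitrary examples.

```python
import math

import numpy as np

arr = np.zeros((2, 4, 2, 4))
print(arr.size)  # 64

new_shape = (-1, 2, 4)

# Product of the explicitly specified dimensions: 2 * 4 = 8
known = math.prod(d for d in new_shape if d != -1)
print(arr.size // known)  # 8

print(arr.reshape(new_shape).shape)  # (8, 2, 4)
```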

The code sample flattens all but the last dimension.

You can use a similar approach to flatten all but the last two dimensions.

main.py

```python
import numpy as np

arr = np.zeros((2, 4, 2, 4))

# Keep the last two dimensions (2, 4) and merge the rest: shape (8, 2, 4)
new_arr = arr.reshape(-1, *arr.shape[-2:])

print(new_arr)
```

Running the code sample produces the following output.

shell

`[[[0. 0. 0. 0.] [0. 0. 0. 0.]] [[0. 0. 0. 0.] [0. 0. 0. 0.]] [[0. 0. 0. 0.] [0. 0. 0. 0.]] [[0. 0. 0. 0.] [0. 0. 0. 0.]] [[0. 0. 0. 0.] [0. 0. 0. 0.]] [[0. 0. 0. 0.] [0. 0. 0. 0.]] [[0. 0. 0. 0.] [0. 0. 0. 0.]] [[0. 0. 0. 0.] [0. 0. 0. 0.]]]`
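The same idea generalizes to keeping any number of trailing dimensions. `flatten_leading()` below is a hypothetical helper (not part of NumPy), a minimal sketch assuming `n_keep` is between 1 and `arr.ndim`.

```python
import numpy as np


def flatten_leading(arr: np.ndarray, n_keep: int) -> np.ndarray:
    """Hypothetical helper: merge all but the last `n_keep` dimensions."""
    if not 1 <= n_keep <= arr.ndim:
        raise ValueError("n_keep must be between 1 and arr.ndim")
    # -1 absorbs the product of the leading dimensions
    return arr.reshape(-1, *arr.shape[-n_keep:])


arr = np.zeros((2, 4, 2, 4))

print(flatten_leading(arr, 1).shape)  # (16, 4)
print(flatten_leading(arr, 2).shape)  # (8, 2, 4)
print(flatten_leading(arr, 3).shape)  # (2, 4, 2, 4)
```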

You should always make sure that the new shape is compatible with the original shape.

Flatten only some Dimensions of a NumPy array using `numpy.vstack()`

You can also use the numpy.vstack() method to flatten only some dimensions of a NumPy array.

The method stacks a sequence of arrays vertically (row-wise).
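Passing a single NumPy array works because `np.vstack()` treats it as a sequence of the sub-arrays along its first axis and stacks those. As a result, only the first two dimensions are merged, whatever the array's rank. A minimal sketch (the 4D shape is an arbitrary example):

```python
import numpy as np

arr3 = np.zeros((2, 4, 2))
arr4 = np.zeros((2, 4, 2, 4))

# The 2 sub-arrays of shape (4, 2) are stacked into (8, 2)
print(np.vstack(arr3).shape)  # (8, 2)

# The 2 sub-arrays of shape (4, 2, 4) are stacked into (8, 2, 4)
print(np.vstack(arr4).shape)  # (8, 2, 4)
```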

main.py

```python
import numpy as np

arr = np.zeros((2, 4, 2))
print(arr)

print('-' * 50)

new_arr = np.vstack(arr)
print(new_arr)

print('-' * 50)

print(new_arr.shape)
```


Running the code sample produces the following output.

shell

```shell
[[[0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]]

 [[0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]]]
--------------------------------------------------
[[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
--------------------------------------------------
(8, 2)
```

The `numpy.vstack()` method is equivalent to concatenating along the first axis after 1-D arrays of shape `(N,)` have been reshaped to `(1, N)`.
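A short sketch of that equivalence: for a 3D array, `np.vstack()`, `np.concatenate()` along axis 0 and `reshape(-1, arr.shape[-1])` give the same result, and 1-D inputs are turned into rows before stacking. `np.arange()` is used instead of `np.zeros()` only so that any difference between the results would be detectable.

```python
import numpy as np

arr = np.arange(16).reshape(2, 4, 2)

stacked = np.vstack(arr)
concatenated = np.concatenate(arr, axis=0)
reshaped = arr.reshape(-1, arr.shape[-1])

print(np.array_equal(stacked, concatenated))  # True
print(np.array_equal(stacked, reshaped))  # True

# 1-D arrays of shape (3,) become rows of shape (1, 3) before stacking
rows = np.vstack([np.array([1, 2, 3]), np.array([4, 5, 6])])
print(rows.shape)  # (2, 3)
```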
