The Python Oracle

How does torch.argmax work for 4-dimensions

Become part of the top 3% of the developers by applying to Toptal https://topt.al/25cXVn

--

Music by Eric Matyas
https://www.soundimage.org
Track title: Techno Bleepage Open

--

Chapters
00:00 Question
01:27 Accepted answer (Score 4)
04:27 Answer 2 (Score 0)
04:58 Thank you

--

Full question
https://stackoverflow.com/questions/6342...

--

Content licensed under CC BY-SA
https://meta.stackexchange.com/help/lice...

--

Tags
#python #pytorch

#avk47



ACCEPTED ANSWER

Score 4


If k is a tensor of shape (2, 3, 4, 4), by definition, torch.argmax with axis=1 should give you an output of shape (2, 4, 4). To understand why this happens, you have to understand what happens in lower dimensions first.

If I have a 2D (2, 2) tensor A, like:

[[1,2],
 [3,4]]

Then torch.argmax(A, axis=1) gives an output of shape (2,) with values [1, 1]. The axis argument is the axis along which to operate. Setting axis=1 means that, for each row, it compares the values across the columns before deciding on a max. For row 0, it looks at the column values 1, 2 and decides that 2 (at index 1) is the max. For row 1, it looks at the column values 3, 4 and decides that 4 (at index 1) is the max. So the argmax result is [1, 1].
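A quick sketch of that 2D case (written with dim, PyTorch's own name for the axis argument):

import torch

A = torch.tensor([[1, 2],
                  [3, 4]])

# For each row, report the column index holding the largest value
print(torch.argmax(A, dim=1))  # tensor([1, 1])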

Moving up to 3D, take a hypothetical tensor of dimensions (I, J, K). If we call argmax with axis=1, we can break it down as follows:

import torch

I, J, K = 3, 4, 5
A = torch.rand(I, J, K)
out = torch.zeros((I, K), dtype=torch.int32)

# For every (i, k) pair, take the argmax over the J values along axis 1
for i in range(I):
    for k in range(K):
        out[i, k] = torch.argmax(A[i, :, k])

print(out)
print(torch.argmax(A, axis=1))

Out:
tensor([[3, 3, 2, 3, 2],
        [1, 1, 0, 1, 0],
        [0, 1, 0, 3, 3]], dtype=torch.int32)
tensor([[3, 3, 2, 3, 2],
        [1, 1, 0, 1, 0],
        [0, 1, 0, 3, 3]])

So what happens is, in your 3D tensor, you're once again calculating the argmax along axis 1 (the columns). For each unique pair (i, k), you have exactly J values along axis 1, and the index of the maximum value within those J values is inserted into position (i, k) of the output.

If you understand this, then you can understand what happens in 4D. For any 4D tensor of dimensions (I, J, K, L), if you call argmax with axis=1, then for each combination of (i, k, l) you'll have exactly J values along axis 1 - and the argmax of those J values will be present at output[i,k,l].
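To make that concrete, here is the same loop-based sketch extended to 4D (the shapes are picked arbitrarily for illustration):

import torch

I, J, K, L = 2, 3, 4, 5
A = torch.rand(I, J, K, L)
out = torch.zeros((I, K, L), dtype=torch.int64)

# For every (i, k, l) combination, take the argmax over the J values along axis 1
for i in range(I):
    for k in range(K):
        for l in range(L):
            out[i, k, l] = torch.argmax(A[i, :, k, l])

print(torch.equal(out, torch.argmax(A, dim=1)))  # True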

The keepdim argument merely conserves the number of dimensions of your tensor. For example, argmax along axis 1 on the 4D tensor gives a 3D result of shape (I, K, L), but with keepdim the result stays 4D, with shape (I, 1, K, L).
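A minimal sketch of that shape difference, using the (2, 3, 4, 4) tensor k from the question:

import torch

k = torch.rand(2, 3, 4, 4)

print(torch.argmax(k, dim=1).shape)                # torch.Size([2, 4, 4])
print(torch.argmax(k, dim=1, keepdim=True).shape)  # torch.Size([2, 1, 4, 4])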




ANSWER 2

Score 0


Argmax gives the index corresponding to the highest value along a given dimension, so the number of dimensions is not an issue. When you apply argmax along a given dimension, PyTorch by default collapses that dimension, since its values are replaced by a single index. If you don't want to remove that dimension and would rather keep it with size one, you can use keepdim=True.
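A small sketch of that collapsing behaviour; the shapes below are arbitrary and only for illustration:

import torch

a = torch.rand(3, 5, 7)        # 3D
b = torch.rand(2, 3, 4, 5, 6)  # 5D

# The reduced dimension is dropped by default...
print(torch.argmax(a, dim=1).shape)                # torch.Size([3, 7])
print(torch.argmax(b, dim=1).shape)                # torch.Size([2, 4, 5, 6])

# ...or kept with size 1 when keepdim=True
print(torch.argmax(b, dim=1, keepdim=True).shape)  # torch.Size([2, 1, 4, 5, 6])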