Flattening is available in three forms in PyTorch:

1. As a tensor method (OOP style), `torch.Tensor.flatten`, applied directly on a tensor: `x.flatten()`.
2. As a function (functional form), `torch.flatten`, applied as: `torch.flatten(x)`.
3. As a module (layer, `nn.Module`), `nn.Flatten()`. Generally used in a model definition.
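For instance, all three forms below produce the same result on the same input (a minimal sketch; the tensor `x` is just an arbitrary example):

```python
import torch
from torch import nn

x = torch.arange(6).reshape(2, 3)   # example 2D input

a = x.flatten()                     # tensor method
b = torch.flatten(x)                # functional form
c = nn.Flatten(start_dim=0)(x)      # module form; start_dim=0 to match the other two

# all three give tensor([0, 1, 2, 3, 4, 5])
print(torch.equal(a, b) and torch.equal(b, c))   # True
```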
All three are identical and share the same implementation; the only difference is that `nn.Flatten` has `start_dim` set to `1` by default, to avoid flattening the first axis (usually the batch axis), while the other two flatten from `axis=0` to `axis=-1`, i.e. the entire tensor, if no arguments are given.
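To see that default concretely, here is a small sketch with an assumed batch-shaped input:

```python
import torch
from torch import nn

x = torch.zeros(32, 3, 8, 8)        # e.g., a batch of 32 images

print(torch.flatten(x).shape)       # torch.Size([6144])    - flattens everything
print(nn.Flatten()(x).shape)        # torch.Size([32, 192]) - batch axis preserved
```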
---
You can think of the job of `torch.flatten()` as simply performing a flattening operation on the tensor, with no strings attached: you give it a tensor, it flattens it and returns the result. That's all there is to it.
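For reference, `torch.flatten()` also accepts optional `start_dim` and `end_dim` arguments if you only want to flatten a sub-range of axes:

```python
import torch

x = torch.zeros(2, 3, 4, 5)

print(torch.flatten(x).shape)               # torch.Size([120])     - default: whole tensor
print(torch.flatten(x, start_dim=1).shape)  # torch.Size([2, 60])
print(torch.flatten(x, 1, 2).shape)         # torch.Size([2, 12, 5])
```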
On the contrary, `nn.Flatten()` is much more sophisticated (i.e., it's a neural net layer). Being object oriented, it inherits from `nn.Module`, although internally it uses the plain `tensor.flatten()` op in its `forward()` method to flatten the tensor. You can think of it more as syntactic sugar over `torch.flatten()`.
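To illustrate the idea, a module equivalent in spirit to `nn.Flatten()` could look like the hand-written sketch below (this is not PyTorch's actual source):

```python
import torch
from torch import nn

class MyFlatten(nn.Module):
    """Sketch of a Flatten-style layer: store the dims, defer to tensor.flatten()."""
    def __init__(self, start_dim: int = 1, end_dim: int = -1):
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # the actual work is just the plain tensor op
        return x.flatten(self.start_dim, self.end_dim)

# behaves like nn.Flatten()
t = torch.arange(24).reshape(2, 3, 4)
print(MyFlatten()(t).shape)    # torch.Size([2, 12])
print(nn.Flatten()(t).shape)   # torch.Size([2, 12])
```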
**Important difference**: a notable distinction is that `torch.flatten()`, with its default arguments, always returns a 1D tensor, provided that the input is at least 1D, whereas `nn.Flatten()`, with its default `start_dim=1`, always returns a 2D tensor, provided that the input is at least 2D (with a 1D tensor as input, it throws an `IndexError`).
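For example (the exact `IndexError` message may vary across PyTorch versions):

```python
import torch
from torch import nn

t1 = torch.arange(5)                 # 1D
t2 = torch.arange(6).reshape(2, 3)   # 2D

print(torch.flatten(t2).shape)       # torch.Size([6])    -> 1D
print(nn.Flatten()(t2).shape)        # torch.Size([2, 3]) -> 2D

# nn.Flatten() on a 1D input fails, since start_dim=1 is out of range:
try:
    nn.Flatten()(t1)
except IndexError as e:
    print("IndexError:", e)
```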
**Comparisons**:

- `torch.flatten()` is an API whereas `nn.Flatten()` is a neural net layer.
- `torch.flatten()` is a Python *function* whereas `nn.Flatten()` is a Python *class*.
- Because of the above point, `nn.Flatten()` comes with lots of methods and attributes.
- `torch.flatten()` can be used in the wild (e.g., for simple tensor ops) whereas `nn.Flatten()` is expected to be used as one of the layers in an `nn.Sequential()` block (see the sketch after the code demo).
- `torch.flatten()` has no information about the computation graph unless it is plugged into another graph-aware block (with the `tensor.requires_grad` flag set to `True`) whereas `nn.Flatten()` is always tracked by autograd.
- `torch.flatten()` cannot accept layers (e.g., linear/conv1d) as inputs, whereas `nn.Flatten()` is mostly used between such layers, flattening the output of one before feeding it to the next.
- Both `torch.flatten()` and `nn.Flatten()` return *views* of the input tensor (when the input's memory layout allows it; otherwise a copy is returned). Thus, any modification to the result also affects the input tensor. (See the code below.)
**Code demo**:

```
# imports needed for the demo
In [107]: import torch
In [108]: from torch import nn

# input tensors to work with
In [109]: t1 = torch.arange(12).reshape(3, -1)
In [110]: t2 = torch.arange(12, 24).reshape(3, -1)
In [111]: t3 = torch.arange(12, 36).reshape(3, 2, -1)   # 3D tensor
```
Flattening with `torch.flatten()`:

```
In [113]: t1flat = torch.flatten(t1)

In [114]: t1flat
Out[114]: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

# modification to the flattened tensor
In [115]: t1flat[-1] = -1

# input tensor is also modified; thus flattening returns a view.
In [116]: t1
Out[116]:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, -1]])
```
Flattening with `nn.Flatten()`:

```
In [123]: nnfl = nn.Flatten()

In [124]: t3flat = nnfl(t3)

# note that the result is 2D, as opposed to 1D with torch.flatten
In [125]: t3flat
Out[125]:
tensor([[12, 13, 14, 15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26, 27],
        [28, 29, 30, 31, 32, 33, 34, 35]])

# modification to the result
In [126]: t3flat[-1, -1] = -1

# input tensor is also modified; thus the flattened result is a view.
In [127]: t3
Out[127]:
tensor([[[12, 13, 14, 15],
         [16, 17, 18, 19]],

        [[20, 21, 22, 23],
         [24, 25, 26, 27]],

        [[28, 29, 30, 31],
         [32, 33, 34, -1]]])
```
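As a usage example for the `nn.Sequential()` point above (a sketch with made-up layer sizes), here is `nn.Flatten()` sitting between a conv layer and a linear layer, with autograd tracking the result:

```python
import torch
from torch import nn

# hypothetical toy network; the layer sizes are arbitrary
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3),   # (N, 1, 8, 8) -> (N, 4, 6, 6)
    nn.ReLU(),
    nn.Flatten(),                     # (N, 4, 6, 6) -> (N, 144)
    nn.Linear(4 * 6 * 6, 10),
)

x = torch.randn(2, 1, 8, 8)
out = model(x)
print(out.shape)     # torch.Size([2, 10])
print(out.grad_fn)   # not None: the flatten step is tracked by autograd
```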
**Tidbit**: `torch.flatten()` is the precursor to `nn.Flatten()` and its brethren `nn.Unflatten()`, since it has existed from the very beginning. Then there was a legitimate use case for `nn.Flatten()`, since flattening is a common requirement for almost all ConvNets (just before the softmax or elsewhere), so it was added later on, in PR #22245. There are also recent proposals to use `nn.Flatten()` in ResNets for model surgery.