Flattening is available in three forms in PyTorch:

  • As a tensor method (OOP style): torch.Tensor.flatten, applied directly on a tensor: x.flatten().

  • As a function (functional form): torch.flatten, applied as torch.flatten(x).

  • As a module (a layer, nn.Module): nn.Flatten(). Generally used in a model definition.

    All three are identical and share the same implementation; the only difference is that nn.Flatten has start_dim set to 1 by default, to avoid flattening the first axis (usually the batch axis), while the other two flatten from axis=0 to axis=-1 - i.e., the entire tensor - if no arguments are given. The sketch below shows the difference.
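
    A quick sketch of all three forms (the tensor shape here is arbitrary, chosen only to make the start_dim difference visible):

    import torch
    from torch import nn

    x = torch.zeros(2, 3, 4)                 # e.g., a batch of 2 samples

    x.flatten().shape                        # torch.Size([24])   - method form, flattens everything
    torch.flatten(x).shape                   # torch.Size([24])   - functional form, same result
    nn.Flatten()(x).shape                    # torch.Size([2, 12]) - module form, start_dim=1 keeps the batch axis
    torch.flatten(x, start_dim=1).shape      # torch.Size([2, 12]) - equivalent functional call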

    (Comment by Ivan, Feb 2, 2021) Found this: torch/csrc/utils/tensor_flatten.h. Seems like it uses view, which is a reshape!

    You can think of the job of torch.flatten() as simply flattening a tensor, with no strings attached: you give it a tensor, it flattens it and returns the result. That's all there is to it.

    By contrast, nn.Flatten() is much more sophisticated (i.e., it's a neural net layer). Being object oriented, it inherits from nn.Module, although it internally uses the plain tensor.flatten() OP in its forward() method to flatten the tensor. You can think of it as syntactic sugar over torch.flatten(), as the sketch below suggests.
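
    As a rough illustration, here is a minimal sketch of what nn.Flatten() effectively does. MyFlatten is a made-up name, and the real implementation in torch.nn handles more details; this is only meant to show the delegation to the tensor OP:

    import torch
    from torch import nn

    class MyFlatten(nn.Module):
        # hypothetical minimal re-implementation, for illustration only
        def __init__(self, start_dim=1, end_dim=-1):
            super().__init__()
            self.start_dim = start_dim
            self.end_dim = end_dim

        def forward(self, input):
            # delegate to the plain tensor OP, exactly as described above
            return input.flatten(self.start_dim, self.end_dim)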

    Important difference: a notable distinction is that torch.flatten(), with default arguments, always returns a 1D tensor, provided the input is at least 1D, whereas nn.Flatten() always returns a 2D tensor, provided the input is at least 2D (with a 1D tensor as input, it throws an IndexError).
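
    A quick check of that claim (the tensor values are arbitrary):

    import torch
    from torch import nn

    t = torch.arange(24).reshape(2, 3, 4)

    torch.flatten(t).shape    # torch.Size([24])   - 1D, the whole tensor is flattened
    nn.Flatten()(t).shape     # torch.Size([2, 12]) - 2D, the batch axis survives

    v = torch.arange(5)       # 1D input
    nn.Flatten()(v)           # raises IndexError: dimension out of range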

    Comparisons:

  • torch.flatten() is a plain tensor operation in the torch API, whereas nn.Flatten() is a neural net layer.

  • torch.flatten() is a Python function, whereas nn.Flatten() is a Python class.

  • Because of the above point, nn.Flatten() comes with a lot of methods and attributes inherited from nn.Module.

  • torch.flatten() can be used in the wild (e.g., for simple tensor OPs), whereas nn.Flatten() is expected to be used inside an nn.Sequential() block as one of the layers.

  • torch.flatten() participates in the computation graph only when its input is graph-aware (i.e., has its requires_grad flag set to True), whereas nn.Flatten(), sitting as a layer inside a model, is routinely tracked by autograd (see the autograd sketch after the code demo below).

  • torch.flatten() operates on plain tensors and is not composed with other layers, whereas nn.Flatten() is designed to sit between neural net layers (e.g., after conv layers, before linear ones) in a model.

  • Both torch.flatten() and nn.Flatten() return a view of the input tensor whenever possible (per the docs, torch.flatten() may return the original object, a view, or a copy). When a view is returned, any modification to the result also affects the input tensor. (See the code below.)

    Code demo:

    # input tensors to work with
    In [109]: t1 = torch.arange(12).reshape(3, -1)
    In [110]: t2 = torch.arange(12, 24).reshape(3, -1)
    In [111]: t3 = torch.arange(12, 36).reshape(3, 2, -1)   # 3D tensor
    

    Flattening with torch.flatten():

    In [113]: t1flat = torch.flatten(t1)
    In [114]: t1flat
    Out[114]: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
    # modification to the flattened tensor    
    In [115]: t1flat[-1] = -1
    # input tensor is also modified; thus flattening is a view.
    In [116]: t1
    Out[116]: 
    tensor([[ 0,  1,  2,  3],
            [ 4,  5,  6,  7],
            [ 8,  9, 10, -1]])
    

    Flattening with nn.Flatten():

    In [123]: nnfl = nn.Flatten()
    In [124]: t3flat = nnfl(t3)
    # note that the result is 2D, as opposed to 1D with torch.flatten
    In [125]: t3flat
    Out[125]: 
    tensor([[12, 13, 14, 15, 16, 17, 18, 19],
            [20, 21, 22, 23, 24, 25, 26, 27],
            [28, 29, 30, 31, 32, 33, 34, 35]])
    # modification to the result
    In [126]: t3flat[-1, -1] = -1
    # input tensor also modified. Thus, flattened result is a view.
    In [127]: t3
    Out[127]: 
    tensor([[[12, 13, 14, 15],
             [16, 17, 18, 19]],
            [[20, 21, 22, 23],
             [24, 25, 26, 27]],
            [[28, 29, 30, 31],
             [32, 33, 34, -1]]])
    
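    Regarding the autograd point from the comparison list above, a minimal sketch (shapes are arbitrary): both routes are tracked once the input tensor requires gradients.

    import torch
    from torch import nn

    x = torch.ones(2, 3, 4, requires_grad=True)

    y1 = torch.flatten(x)     # tracked: y1.grad_fn is not None
    y2 = nn.Flatten()(x)      # tracked as well, via the module's forward()

    y1.sum().backward()
    print(x.grad.shape)       # torch.Size([2, 3, 4]) - gradients flow back through the flatten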

    Tidbit: torch.flatten() is the precursor to nn.Flatten() and its sibling nn.Unflatten(), having existed from the very beginning. nn.Flatten() was added later, in PR #22245, because there was a legitimate use-case for it: flattening is a common requirement for almost all ConvNets (just before the softmax or elsewhere). A minimal usage sketch follows.
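
    A minimal sketch of that common ConvNet placement (the layer sizes here are arbitrary assumptions, not from any particular model):

    import torch
    from torch import nn

    model = nn.Sequential(
        nn.Conv2d(1, 8, kernel_size=3),   # (N, 1, 28, 28) -> (N, 8, 26, 26)
        nn.ReLU(),
        nn.Flatten(),                     # (N, 8, 26, 26) -> (N, 8*26*26)
        nn.Linear(8 * 26 * 26, 10),       # class scores
    )

    out = model(torch.randn(4, 1, 28, 28))
    print(out.shape)                      # torch.Size([4, 10])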

    There are also recent proposals to use nn.Flatten() in ResNets for model surgery.
