
I would like to 'translate' a PyTorch model to another framework (non-TF/Keras).

I'm trying to take a PyTorch model and automate the translation to the other framework, which contains similar types of layers (e.g. conv2d, dense, ...).

Is there a way, from PyTorch directly or through ONNX, to retrieve a model's layers, their types, shapes, and connections? (Weights are not important so far.)

If the model is sequential, then you can infer the architecture of the network from its layers directly. For any model that is more complex, i.e. contains logic other than purely sequential layers, you won't be able to read that from the layers themselves. In other words, what you've defined as "connections" in your question is only available to the user as Python code, inside the forward definition of that model. – Ivan Feb 9, 2022 at 22:33

Thanks, @Ivan. Yes, I understand that, but isn't ONNX used to export models from frameworks, acting as an intermediary between frameworks? If that's the case, I would imagine the whole logic, including the forward call, must somehow be saved in the ONNX file? (As I understand it, ONNX requires a forward pass to generate the computational graph so as to save it fully?) – user452306 Feb 9, 2022 at 22:43

@user452306 You are correct: you can inspect an ONNX graph and get all that information. The main thing is you will get ONNX operators that are not always mapped 1:1 from torch; nn.Linear is often a Gemm in ONNX, for example, but can sometimes show up as MatMul and Add (for the bias). ONNX operator reference: github.com/onnx/onnx/blob/main/docs/Operators.md – IceTDrinker Feb 11, 2022 at 14:25

@IceTDrinker Thanks! Yeah, I've seen the list of operators and was able to access them through the ONNX graph. Do you know if we're able to access/retrieve the shapes of these layers/operations, as well as how the layers are connected to each other (i.e. for skip connections and such)? – user452306 Feb 11, 2022 at 16:57

@user452306 Yes, in ONNX each node's outputs are named: you can check node.output (it's a list of strings), and each node has a list called node.input, where the string at index i indicates which previous output feeds that input. For the shapes there is something called shape inference in ONNX: github.com/onnx/onnx/blob/main/docs/ShapeInference.md and for Python: github.com/onnx/onnx/blob/main/docs/… I don't remember how to extract the shape info, but it should help you. I will put the info in a proper answer. – IceTDrinker Feb 11, 2022 at 17:34

From discussion in comments on your question:

Each node in ONNX has a list of named inputs and a list of named outputs.

For the input list, accessed with node.input, each input index holds either the name of a graph input that feeds that input or the name of a previous node's output that feeds it. There are also initializers, which are the ONNX parameters (weights).

# model is an ONNX ModelProto, e.g. loaded with onnx.load("model.onnx")
graph = model.graph
# graph inputs (ValueInfoProto entries; .name gives the tensor name)
for graph_input in graph.input:
    print(graph_input.name)
# graph parameters (initializers hold the weights)
for init in graph.initializer:
    print(init.name)
# graph outputs
for graph_output in graph.output:
    print(graph_output.name)
# iterate over nodes
for node in graph.node:
    print(node.op_type)
    # node inputs
    for idx, node_input_name in enumerate(node.input):
        print(idx, node_input_name)
    # node outputs
    for idx, node_output_name in enumerate(node.output):
        print(idx, node_output_name)
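To recover the connections (including skip connections), you can match each node's input names against the node that produced them. A minimal sketch of that matching, using a hand-built stand-in for graph.node so it runs without the onnx package (the triples mimic a Conv → Relu → Add residual block; with a real model you would build the same structures from node.op_type, node.input, and node.output):

```python
# Stand-in for an ONNX node list: (op_type, inputs, outputs).
nodes = [
    ("Conv", ["x", "w0"], ["conv_out"]),
    ("Relu", ["conv_out"], ["relu_out"]),
    ("Add", ["relu_out", "x"], ["add_out"]),  # skip connection from "x"
]

# Map each tensor name to the index of the node that produces it.
producer = {}
for i, (_, _, outputs) in enumerate(nodes):
    for name in outputs:
        producer[name] = i

# Edges are (source_node_idx, destination_node_idx); inputs with no
# producer are graph inputs or initializers, not node-to-node edges.
edges = []
for j, (_, inputs, _) in enumerate(nodes):
    for name in inputs:
        if name in producer:
            edges.append((producer[name], j))

print(edges)  # [(0, 1), (1, 2)]
```

Here the Add node's second input, "x", has no producer node, which tells you it is a skip connection straight from a graph input.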

Shape inference is documented in the ONNX repo (github.com/onnx/onnx/blob/main/docs/ShapeInference.md), which also covers the Python API.

The gist for Python:

from onnx import shape_inference
# original_model is an ONNX ModelProto
inferred_model = shape_inference.infer_shapes(original_model)

and find the shape info in inferred_model.graph.value_info.
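Each entry in value_info is a ValueInfoProto whose tensor type carries a list of dimensions; each dim has either a concrete dim_value or a symbolic dim_param (e.g. a dynamic batch axis). A sketch of turning those dims into a plain Python tuple; the SimpleNamespace objects are stand-ins for the protobuf dim entries so the snippet runs without the onnx package:

```python
from types import SimpleNamespace

def dims_to_tuple(dims):
    """Convert ONNX-style dim entries to a tuple: ints for fixed
    sizes, strings for symbolic (e.g. batch) dimensions."""
    shape = []
    for d in dims:
        if d.dim_param:          # symbolic dimension, e.g. "batch"
            shape.append(d.dim_param)
        else:                    # concrete dimension
            shape.append(d.dim_value)
    return tuple(shape)

# Stand-in for vi.type.tensor_type.shape.dim of one value_info entry:
dims = [
    SimpleNamespace(dim_param="batch", dim_value=0),
    SimpleNamespace(dim_param="", dim_value=64),
    SimpleNamespace(dim_param="", dim_value=28),
    SimpleNamespace(dim_param="", dim_value=28),
]
print(dims_to_tuple(dims))  # ('batch', 64, 28, 28)
```

With a real inferred model you would apply the same helper per entry, e.g. for vi in inferred_model.graph.value_info: print(vi.name, dims_to_tuple(vi.type.tensor_type.shape.dim)).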

You can also use Netron (available on GitHub) to get a visual representation of that information.

Is there any way to get the data type of the operands in each layer, like whether it is int8, fp16, or fp32? I know how to do this using Netron by manually going to each layer, but I want to automate the process and hence want a Python script method. – Madhuparna Bhowmik Jul 10, 2022 at 22:46

@MadhuparnaBhowmik IIRC you need to check some value_info in the graph to know what type the input/output tensors are – IceTDrinker Jul 11, 2022 at 10:04
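Following up on that comment: each value_info entry stores the element type as an integer from the onnx.TensorProto DataType enum (FLOAT = 1, INT8 = 3, FLOAT16 = 10, and so on). A sketch of mapping those codes to readable names, using a partial hand-written table so it runs without the onnx package (with onnx installed you could instead ask the protobuf enum itself for the name):

```python
# Partial map of onnx.TensorProto elem_type codes to readable names
# (values from the ONNX TensorProto.DataType enum).
ELEM_TYPE_NAMES = {
    1: "float32",
    3: "int8",
    6: "int32",
    7: "int64",
    10: "float16",
    11: "float64",
}

def elem_type_name(code):
    """Return a readable dtype name for an ONNX elem_type code."""
    return ELEM_TYPE_NAMES.get(code, f"unknown({code})")

# With a real inferred model you would iterate, e.g.:
# for vi in inferred_model.graph.value_info:
#     print(vi.name, elem_type_name(vi.type.tensor_type.elem_type))
print(elem_type_name(1), elem_type_name(3), elem_type_name(10))
```

Run shape inference first so value_info is populated for intermediate tensors; graph inputs and outputs carry the same elem_type field directly.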
