Importing PyTorch models in MATLAB using importNetworkFromPyTorch
Mohammed Saifur
on 10 May 2023
Answered: MathWorks Deep Learning Toolbox Team
on 16 Apr 2024
Hello,
I am trying to import a pretrained PyTorch model into MATLAB using the importNetworkFromPyTorch function from Deep Learning Toolbox. However, I get the following error:
Error using pytorchmex
Traced model failed to load. Trace the model in the fully supported version of PyTorch as described in Deep Learning Toolbox Converter for PyTorch Models.
Error in nnet.internal.cnn.pytorch_importer.architecture.ModuleManager/loadModule (line 28)
PropertyCell = nnet.internal.cnn.pytorch_importer.architecture.pytorchmex(this.Filename);
Error in nnet.internal.cnn.pytorch_importer.architecture.ModuleManager (line 16)
PropertyCell = loadModule(this);
Error in nnet.internal.cnn.pytorch_importer.architecture.util.translatePyTorchFile (line 9)
nnet.internal.cnn.pytorch_importer.architecture.ModuleManager(filename);
Error in nnet.internal.cnn.pytorch_importer.architecture.importNetworkFromPyTorch (line 18)
importManager = nnet.internal.cnn.pytorch_importer.architecture.util.translatePyTorchFile(filename, customLayerPath);
Error in importNetworkFromPyTorch (line 36)
Network = nnet.internal.cnn.pytorch_importer.architecture.importNetworkFromPyTorch(modelfile, varargin{:});
Error in mnist2mat (line 1)
net = importNetworkFromPyTorch("mnist_cnn.pt");
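The message says the traced model failed to load, which suggests the file is not a TorchScript archive. A common cause (an assumption here, since the question does not show how `mnist_cnn.pt` was created) is a file written with `torch.save`, for example a `state_dict`, rather than a traced model saved with `torch.jit.save` or the traced module's `save` method. A rough way to check a file in Python is to try opening it with `torch.jit.load`. The helper `is_torchscript_file` and the `demo` function below are hypothetical illustrations, not part of PyTorch or MATLAB.

```python
import os
import tempfile

import torch
import torch.nn as nn


def is_torchscript_file(path: str) -> bool:
    """Hypothetical helper: True if torch.jit.load can open the file."""
    try:
        torch.jit.load(path, map_location="cpu")
        return True
    except Exception:
        # Files written with torch.save (e.g. a state_dict) are not
        # TorchScript archives, so torch.jit.load raises an error.
        return False


def demo() -> tuple:
    """Save a tiny model both ways and check each file."""
    model = nn.Linear(4, 2).eval()
    with tempfile.TemporaryDirectory() as tmp:
        weights_path = os.path.join(tmp, "weights_only.pt")
        traced_path = os.path.join(tmp, "traced.pt")
        # Only the pickled parameters: not loadable as TorchScript
        torch.save(model.state_dict(), weights_path)
        # A traced TorchScript module
        torch.jit.trace(model, torch.rand(1, 4)).save(traced_path)
        return is_torchscript_file(weights_path), is_torchscript_file(traced_path)


print(demo())  # (False, True)
```

If the check returns False for a file, it most likely needs to be traced before MATLAB can import it.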
Accepted Answer
MathWorks Deep Learning Toolbox Team
on 16 Apr 2024
The model must be traced in PyTorch before it can be imported into MATLAB. See the PyTorch documentation for details on how to do this: https://pytorch.org/docs/stable/generated/torch.jit.trace.html
You can also read this blog post for additional information: https://blogs.mathworks.com/deep-learning/2022/10/04/whats-new-in-interoperability-with-tensorflow-and-pytorch/
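The torch.jit.trace documentation linked above also notes a limitation worth keeping in mind: tracing records only the operations that actually run for the example input, so Python control flow that depends on tensor values is frozen into the trace. The toy module below is hypothetical and exists only to show the effect; a real model with such branches may need restructuring before its trace behaves like the original.

```python
import warnings

import torch
import torch.nn as nn


class SignDependent(nn.Module):
    """Hypothetical module whose branch depends on the input values."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 2
        return -x


model = SignDependent().eval()

with warnings.catch_warnings():
    # Tracing a value-dependent branch emits a TracerWarning; silence it here.
    warnings.simplefilter("ignore", torch.jit.TracerWarning)
    # The example input is positive, so only the "x * 2" branch is recorded.
    traced = torch.jit.trace(model, torch.ones(3))

x_neg = -torch.ones(3)
print(model(x_neg))   # eager model takes the other branch: tensor([1., 1., 1.])
print(traced(x_neg))  # the trace replays "x * 2": tensor([-2., -2., -2.])
```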
As a simple example, try something similar to the following in PyTorch:
# This example loads a pretrained PyTorch model from torchvision,
# traces it with an example input, and saves the trace as a .pt file.
import torch
from torchvision import models
# Load the model with pretrained weights
# (torchvision 0.13 and later use "weights=" instead of the deprecated "pretrained=True")
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT)
# Call "eval" so that layers such as batch norm and dropout run in
# inference mode
model.eval()
# Move the model to the CPU
model.to("cpu")
# Create an example input: a batch of one 3-channel 224x224 image
X = torch.rand(1, 3, 224, 224)
# Trace the model with the example input
traced_model = torch.jit.trace(model, X)
# Save the traced model to a .pt file
traced_model.save("traced_mobilenet_v2.pt")
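Applied to the network from the question, the same steps might look like the sketch below. It assumes (the question does not say) that `mnist_cnn.pt` holds only a `state_dict` written with `torch.save`, and that the network is a small MNIST CNN with the layer sizes shown. The `Net` class must be replaced by the exact class that produced the weights, otherwise `load_state_dict` fails. `Net`, `trace_and_check`, and the output file name are hypothetical. After saving, the sketch reloads the trace with `torch.jit.load` and compares it with the eager model, a cheap sanity check before calling `importNetworkFromPyTorch` in MATLAB.

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Hypothetical MNIST CNN; must match the class that produced the weights."""

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)  # 64 channels * 12 * 12 after pooling
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = torch.flatten(self.dropout1(x), 1)
        x = self.dropout2(F.relu(self.fc1(x)))
        return F.log_softmax(self.fc2(x), dim=1)


def trace_and_check(weights_path: Optional[str], out_path: str) -> torch.jit.ScriptModule:
    model = Net()
    if weights_path is not None:
        # Assumes a state_dict saved with torch.save(model.state_dict(), ...)
        model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    model.eval()  # dropout must be in inference mode before tracing
    example = torch.rand(1, 1, 28, 28)  # batch, channels, height, width
    torch.jit.trace(model, example).save(out_path)

    # Reload the saved file and confirm it reproduces the eager model's output.
    reloaded = torch.jit.load(out_path, map_location="cpu")
    with torch.no_grad():
        torch.testing.assert_close(reloaded(example), model(example))
    return reloaded
```

For example, `trace_and_check("mnist_cnn.pt", "traced_mnist_cnn.pt")` in Python, followed by `net = importNetworkFromPyTorch("traced_mnist_cnn.pt");` in MATLAB.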