Reinitialize a network with weights from previous training

Hi,
I have a custom neural network that I trained on an old dataset. I now want to retrain the same model architecture on new data. One option would be to append the new data to the old data and shuffle, but since the new data is somewhat different, I want to use some basic transfer learning.
For this purpose, I need to reinitialize my network with the trained weights from the old data, which I have already saved in a MAT-file. I want to know how to use these trained weights with the new network, i.e., essentially create a new network (same structure as before), initialize it with the weights from the previously trained network, and retrain it with the new data.
Best Regards and thanks in advance. :)
Networklayer = [...
    sequenceInputLayer(featureDimension)
    fullyConnectedLayer(4*numHiddenUnits1)
    reluLayer
    fullyConnectedLayer(4*numHiddenUnits1)
    reluLayer
    fullyConnectedLayer(8*numHiddenUnits1)
    reluLayer
    gruLayer(LSTMStateNum,'OutputMode','sequence',InputWeightsInitializer='he',RecurrentWeightsInitializer='he')
    fullyConnectedLayer(8*numHiddenUnits1)
    reluLayer
    fullyConnectedLayer(4*numHiddenUnits1)
    reluLayer
    fullyConnectedLayer(numResponses)
    regressionLayer];

Answers (1)

Udit06 on 16 Oct 2023
Hi Vasu,
I understand that you want to use the weights of a model trained on an old dataset as the starting point for retraining the same architecture on a new dataset. You can do this as follows:
  1. Create a new network with the same architecture as the old network.
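  2. Initialize the new network with the trained weights from the old network. The "setwb" function does this for shallow networks ("network" objects, such as those created with "feedforwardnet"), but it does not accept a layer array trained with "trainNetwork", as in your case. For such a network, the learned parameters are stored in the properties of each layer (for example, Weights and Bias of a fullyConnectedLayer), and "trainNetwork" uses these values as the starting point when they are nonempty. You can therefore reuse the layers of the saved network:
load('oldNetwork.mat', 'net'); % file and variable names are placeholders for your saved network
layers = net.Layers;           % layer array that carries the trained weights
  3. Train the new network with the new data using the "trainNetwork" function.
new_network = trainNetwork(XNew, YNew, layers, options); % XNew/YNew: new predictors and responses; options: created with trainingOptions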
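Putting the steps together, the following is a self-contained sketch of the workflow for a network like the one in the question, assuming R2023a with Deep Learning Toolbox. The reduced layer sizes, the synthetic sequence data, the temporary MAT-file, the mini-batch size, and the smaller InitialLearnRate for the second training run are assumptions chosen for illustration, not values from the question. It trains an "old" network, saves and reloads it, and passes the reloaded layers to "trainNetwork" so that training on the new data starts from the old weights.

```matlab
% Sketch: retrain a trainNetwork model on new data, starting from the
% weights learned on the old data. Sizes and data are illustrative only.
rng(0);                                   % reproducible example
featureDimension = 3;
numResponses = 1;
numHiddenUnits = 16;
numSeq = 40;
seqLen = 20;

% Synthetic sequence-to-sequence regression data (old and new)
XOld = cell(numSeq, 1); YOld = cell(numSeq, 1);
XNew = cell(numSeq, 1); YNew = cell(numSeq, 1);
for i = 1:numSeq
    XOld{i} = rand(featureDimension, seqLen);
    YOld{i} = sum(XOld{i}, 1);            % numResponses-by-seqLen
    XNew{i} = rand(featureDimension, seqLen);
    YNew{i} = sum(XNew{i}, 1) + 0.1;      % slightly different relationship
end

% Reduced version of the architecture in the question
layers = [
    sequenceInputLayer(featureDimension)
    fullyConnectedLayer(numHiddenUnits)
    reluLayer
    gruLayer(numHiddenUnits, 'OutputMode', 'sequence')
    fullyConnectedLayer(numResponses)
    regressionLayer];

% First training run on the old data
opts = trainingOptions('adam', ...
    'MaxEpochs', 20, ...
    'MiniBatchSize', 8, ...
    'Verbose', false);
oldNet = trainNetwork(XOld, YOld, layers, opts);

% Save and reload, as with the MAT-file in the question
matFile = [tempname '.mat'];
save(matFile, 'oldNet');
s = load(matFile, 'oldNet');

% The reloaded layers carry the trained weights and biases; trainNetwork
% uses nonempty learnable properties as initial values instead of the
% layers' initializers
transferLayers = s.oldNet.Layers;

% Second training run on the new data with a smaller learning rate
fineTuneOpts = trainingOptions('adam', ...
    'InitialLearnRate', 1e-4, ...
    'MaxEpochs', 20, ...
    'MiniBatchSize', 8, ...
    'Verbose', false);
newNet = trainNetwork(XNew, YNew, transferLayers, fineTuneOpts);

delete(matFile);                          % remove the temporary file
```

Lowering InitialLearnRate is a common choice when fine-tuning, so that the transferred weights are adjusted gradually rather than overwritten; whether it helps here depends on how different the new data is.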
You can refer to the following MathWorks documentation pages for more information about the "setwb" and "trainNetwork" functions:
  1. https://www.mathworks.com/help/deeplearning/ref/setwb.html
  2. https://www.mathworks.com/help/deeplearning/ref/trainnetwork.html
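For comparison, this is roughly how "getwb" and "setwb" are used with a shallow network, which is the case those functions are designed for. It is a minimal sketch assuming the shallow-network functions feedforwardnet, configure, and train; the synthetic data, the hidden-layer size of 10, and the variable names are illustrative assumptions.

```matlab
% Sketch: copy trained weights between two shallow networks with
% getwb/setwb. The synthetic data stands in for the real datasets.
rng(0);                                  % reproducible example
xOld = rand(3, 200);                     % 3 features x 200 samples (old data)
tOld = sum(xOld, 1);                     % 1 target per sample
xNew = rand(3, 100);                     % new data
tNew = sum(xNew, 1) + 0.1;               % slightly shifted targets

% Train the "old" network
oldNet = feedforwardnet(10);
oldNet.trainParam.showWindow = false;    % suppress the training window
oldNet = train(oldNet, xOld, tOld);

% All weights and biases of the old network as a single column vector
wb = getwb(oldNet);

% Same architecture; configure sets the weight sizes (and input/output
% scaling) so that the vector from getwb fits
newNet = feedforwardnet(10);
newNet.trainParam.showWindow = false;
newNet = configure(newNet, xOld, tOld);
newNet = setwb(newNet, wb);              % start from the trained values
wbStart = getwb(newNet);                 % copy kept for comparison

% Continue training on the new data from the transferred weights
newNet = train(newNet, xNew, tNew);
```

configure is called with the old data so that the new network has the same weight sizes and input/output scaling as the old one; without it, the vector returned by getwb would not match the size of the unconfigured network.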
I hope this helps.


Release

R2023a
