How can I freeze specific neuron weights in a feedforward network during training?

I am trying to train a simple feedforward network in which I need to freeze the weights and biases of certain neurons in a particular layer during training, so that those weights do not change after each epoch. I am aware that it is possible to freeze an entire layer, but I am not sure how to proceed with weight-specific freezing of a few neurons.

Accepted Answer

I don't think you can do it without the Deep Learning Toolbox. If you have the Deep Learning Toolbox, you can define a custom layer with mask properties that indicate which weights and biases are to be frozen:
classdef MaskedFullyConnectedLayer < nnet.layer.Layer
    %Defines a fully connected layer with masks indicating which weights and
    %biases are to be learned and which are to be "frozen". The gradient with
    %respect to frozen parameters is always masked to zero, so they never
    %change during training.
    properties (Learnable)
        Weights
        Bias
    end
    properties
        MaskWeights  % true where the weight is trainable
        MaskBias     % true where the bias is trainable
        InputSize
        OutputSize
    end
    methods
        function layer = MaskedFullyConnectedLayer(inputSize, outputSize, maskWeights, maskBias, name)
            % Constructor for the masked fully connected layer.
            % maskWeights and maskBias are logical arrays (1 = frozen, 0 = trainable).
            layer.Name = name;
            layer.Description = "Masked Fully Connected Layer with Frozen Weights/Bias";
            layer.InputSize = inputSize;
            layer.OutputSize = outputSize;
            % Initialize weights and biases
            layer.Weights = randn(outputSize, inputSize) * 0.01;
            layer.Bias = zeros(outputSize, 1);
            % Store the masks inverted: true marks trainable entries
            layer.MaskWeights = ~maskWeights;
            layer.MaskBias = ~maskBias;
        end
        function Y = predict(layer, X)
            Y = layer.Weights*X + layer.Bias;
        end
        function [dLdX, dLdW, dLdB] = backward(layer, X, ~, dLdY, ~)
            xsiz = size(X);
            X    = permute(X,   [1,3,2]);
            dLdY = permute(dLdY,[1,3,2]);
            % Compute gradients, summed over the batch
            dLdW = batchmtimes(dLdY,'none', X,'transpose'); % gradient w.r.t. Weights
            dLdB = batchmtimes(dLdY,'none', 1,'none');      % gradient w.r.t. Bias
            % Zero out the gradients of frozen parameters to prevent updates
            dLdW = dLdW .* layer.MaskWeights;
            dLdB = dLdB .* layer.MaskBias;
            % Gradient w.r.t. the input
            dLdX = reshape( pagemtimes(layer.Weights,'transpose', dLdY,'none'), xsiz);
        end
    end
end
function out = batchmtimes(X,transpX, Y,transpY)
    %Batched matrix multiply summed over pages. Assumes X and Y have already
    %been permuted with permute(__,[1,3,2]).
    out = sum(pagemtimes(X,transpX, Y,transpY), 3);
end
An example of usage would be:
% Define input and output sizes
inputSize = 10;
outputSize = 5;
% Define frozen parameters with logical masks (1 = frozen, 0 = trainable)
frozenWeights = false(outputSize, inputSize);
frozenWeights(1:3, :) = 1; % freeze all incoming weights of the first 3 output neurons
frozenBias = false(outputSize, 1);
frozenBias(1:2) = 1; % freeze the first 2 bias elements
% Create the custom layer
maskedLayer = MaskedFullyConnectedLayer(inputSize, outputSize, frozenWeights, frozenBias, "MaskedFC");
maskedLayer.Weights(frozenWeights) = ___ ; % assign the desired values to the frozen weights
maskedLayer.Bias(frozenBias) = ___ ; % assign the desired values to the frozen biases
% Add it to a network (example with dummy layers)
layers = [
    featureInputLayer(inputSize)
    maskedLayer
    reluLayer
    fullyConnectedLayer(2)
    softmaxLayer
    classificationLayer
    ];
analyzeNetwork(layers);
% Create dummy data for training
XTrain = randn(100, inputSize);
YTrain = categorical(randi([1, 2], 100, 1));
% Train the network
options = trainingOptions('adam', 'MaxEpochs', 5, 'Verbose', true);
net = trainNetwork(XTrain, YTrain, layers, options);
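One way to confirm the frozen entries really stay fixed is to compare them before and after training. A minimal sketch, assuming the variables layers, frozenWeights, XTrain, YTrain, and options from the example above, and assuming the masked layer sits at position 2 of the layer array as it does there:

```matlab
% Snapshot the frozen values before training
maskedIdx = 2; % position of maskedLayer in `layers` (assumption for this example)
before = layers(maskedIdx).Weights(frozenWeights);
% Train as in the example above
net = trainNetwork(XTrain, YTrain, layers, options);
% The frozen entries should be unchanged; the trainable ones will generally differ
after = net.Layers(maskedIdx).Weights(frozenWeights);
assert(isequal(before, after), 'Frozen weights changed during training!')
```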

5 Comments

Thank you for the detailed answer. I will try this and see if it fixes my problem.
You're welcome, but if it does fix the problem, please click Accept on the answer.
I tried to run the example code to understand how the masked layer works, but I keep getting this error:
Error using trainNetwork
Invalid network.
Caused by:
Layer 'MaskedFC': Error using the predict function in layer MaskedFullyConnectedLayer. The function threw an error and could not be executed.
Arrays have incompatible sizes for this operation.
Y = X' .* layer.Weights + layer.Bias;
The dimensions of X were causing the problem, and transposing X fixed that, but it still throws a dimension error no matter what I do. Do you know if I'm doing something wrong?
I edited my post with a corrected version. Please try again.
Yes, now it works. Thank you for your help.


More Answers (0)
