classdef MaskedFullyConnectedLayer < nnet.layer.Layer
    % MaskedFullyConnectedLayer  Fully connected layer with frozen entries.
    %
    %   The layer computes Y = Weights*X + Bias for a 2-D input X
    %   (rows = input features, columns = observations). In backward, the
    %   gradient of every weight/bias entry flagged as frozen is multiplied
    %   by zero, so the loss gradient never updates those entries.
    %
    %   NOTE(review): masking only affects the gradient returned by this
    %   layer. Anything the solver adds outside the layer (for example an
    %   L2 regularization term) is presumably still applied to the frozen
    %   entries - confirm against the training options in use.

    properties (Learnable)
        Weights     % outputSize-by-inputSize weight matrix
        Bias        % outputSize-by-1 bias vector
    end

    properties
        MaskWeights % logical, TRUE where a weight is trainable (inverse of the frozen mask)
        MaskBias    % logical, TRUE where a bias entry is trainable
        InputSize   % number of input features
        OutputSize  % number of output features
    end

    methods
        function layer = MaskedFullyConnectedLayer(inputSize, outputSize, maskWeights, maskBias, name)
            % Construct the layer.
            %
            %   inputSize    - number of input features
            %   outputSize   - number of output features
            %   maskWeights  - frozen mask for Weights: nonzero/true marks an
            %                  entry whose gradient is zeroed. Must be
            %                  outputSize-by-inputSize or expandable to it.
            %   maskBias     - frozen mask for Bias, outputSize-by-1 or
            %                  expandable to it.
            %   name         - (optional) layer name

            % Name is optional; the base class default is kept when omitted.
            if nargin >= 5
                layer.Name = name;
            end
            layer.Description = "Masked Fully Connected Layer with Frozen Weights/Bias";
            layer.InputSize = inputSize;
            layer.OutputSize = outputSize;

            % Small random weights, zero bias.
            layer.Weights = randn(outputSize, inputSize) * 0.01;
            layer.Bias = zeros(outputSize, 1);

            % Store the masks inverted (TRUE = trainable) and expanded to
            % the full parameter shape, so that a wrongly sized mask fails
            % here instead of in the middle of training.
            trainableWeights = ~maskWeights & true(outputSize, inputSize);
            trainableBias = ~maskBias & true(outputSize, 1);
            if ~isequal(size(trainableWeights), [outputSize, inputSize])
                error('MaskedFullyConnectedLayer:maskWeightsSize', ...
                    'maskWeights must be expandable to %d-by-%d.', outputSize, inputSize);
            end
            if ~isequal(size(trainableBias), [outputSize, 1])
                error('MaskedFullyConnectedLayer:maskBiasSize', ...
                    'maskBias must be expandable to %d-by-1.', outputSize);
            end
            layer.MaskWeights = trainableWeights;
            layer.MaskBias = trainableBias;
        end

        function Y = predict(layer, X)
            % Forward pass: affine map applied to every column of X.
            Y = layer.Weights*X + layer.Bias;
        end

        function [dLdX, dLdW, dLdB] = backward(layer, X, ~, dLdY, ~)
            % Backward pass with masked parameter gradients.
            %
            %   X     - layer input (features-by-observations)
            %   dLdY  - loss gradient w.r.t. the layer output
            %
            %   dLdX  - loss gradient w.r.t. X, same size as X
            %   dLdW  - loss gradient w.r.t. Weights, zero at frozen entries
            %   dLdB  - loss gradient w.r.t. Bias, zero at frozen entries

            % Summing the per-observation outer products dLdY(:,b)*X(:,b).'
            % over the batch is the single matrix product dLdY*X.'.
            dLdW = (dLdY * X.') .* layer.MaskWeights;

            % Bias gradient: sum over observations, then mask.
            dLdB = sum(dLdY, 2) .* layer.MaskBias;

            % Input gradient is not masked: frozen weights still take part
            % in the forward computation.
            dLdX = layer.Weights.' * dLdY;
        end
    end
end
function out = batchmtimes(X,transpX, Y,transpY)
% batchmtimes  Page-wise matrix product accumulated over the third dimension.
%
%   out = batchmtimes(X, transpX, Y, transpY) forwards all four arguments
%   to pagemtimes and then sums the resulting pages along dimension 3.
%   transpX and transpY are passed through to pagemtimes unchanged.
    pageProducts = pagemtimes(X, transpX, Y, transpY);
    out = sum(pageProducts, 3);
end
% ---- Usage example: build a masked layer and train with it ----
% NOTE(review): inputSize and outputSize are not defined anywhere in this
% snippet; they must be set before the lines below can run.
% NOTE(review): these are script statements following a classdef and a
% function definition - presumably they belong in a separate script file.

% Frozen mask for the weights: TRUE marks an entry that must not be trained.
frozenWeights = false(outputSize,inputSize);
% Freeze every weight in output rows 1 to 3. If outputSize is smaller than 3
% this assignment grows the mask beyond outputSize rows.
frozenWeights(1:3, :) = 1;
% No bias entry is frozen.
frozenBias = false(outputSize,1);
% The constructor stores the logical negation of both masks (TRUE = trainable).
maskedLayer = MaskedFullyConnectedLayer(inputSize, outputSize, frozenWeights, frozenBias, "MaskedFC");
% The two lines below are templates: '___' is a placeholder, not valid
% MATLAB, and must be replaced with the values the frozen entries should hold.
maskedLayer.Weights(frozenWeights)=___ ;
maskedLayer.Bias(frozenBias)=___ ;
% NOTE(review): the result of this call is not assigned, and the 'layers'
% variable passed to trainNetwork below is never defined in this snippet -
% presumably a layer array containing this input layer and maskedLayer was
% intended; confirm the intended architecture.
featureInputLayer(inputSize)
% Random training data: 100 observations with inputSize features each,
% and labels drawn from two classes.
XTrain = randn(100, inputSize);
YTrain = categorical(randi([1, 2], 100, 1));
% NOTE(review): L2Regularization is left at its default here. If that
% default is nonzero, the solver may still change the frozen entries, since
% the regularization term is not produced by the layer's backward - confirm,
% and consider passing 'L2Regularization', 0.
options = trainingOptions('adam', 'MaxEpochs', 5, 'Verbose', true);
net = trainNetwork(XTrain, YTrain, layers, options);