trainNetwork: è possibile accedere alla rete tramite la funzione di callback durante l'addestramento?

2 visualizzazioni (ultimi 30 giorni)
Sto usando la funzione trainNetwork() per allenare una rete neurale per un problema di regressione, in automatico questa funzione usa il valore di RMSE per valutare le performance della rete, tuttavia vorrei usare un tipo di errore diverso (e.g. MEA). Per fare ciò sto provando a lavorare con il campo 'OutputFcn' della funzione trainingOptions(), il quale permette di chiamare la funzione contenuta in tale campo ad ogni iterazione: la mia idea sarebbe quindi di verificare quando si sta compiendo la validazione e in quel caso calcolare il MEA, confrontare il MEA corrente con quello precedente e terminare il training, ponendo l'output della funzione a 1, quando si ottiene un MEA maggiore del precedente per più di N volte. Tuttavia per poter calcolare il MEA avrei bisogno di accedere alla rete allenata allo stato "attuale" (per calcolare l'errore tra predictions e target), è possibile? Esiste un metodo alternativo per usare il MEA come metrica per la valutazione delle performance della rete?

Risposte (1)

Kautuk Raj
Kautuk Raj il 26 Feb 2024
I capture the fact that you are looking to use Mean Absolute Error (MAE) as a loss function for evaluating the performance of your neural network during training, and you are considering the use of the OutputFcn in trainingOptions to implement this. The performance refers to the model's ability to accurately predict the target values. The loss function quantifies the difference between the predicted outputs of the network and the actual target values. During training, the loss function is used to guide the optimization process by providing a measure of error that the network needs to minimize.
Here is a MATLAB Answers post that demonstrates how to create a custom weighted loss function for regression using the Deep Learning Toolbox making use of the “custom training loop” feature: https://www.mathworks.com/matlabcentral/answers/735062-how-to-create-a-custom-weighted-loss-function-for-regression-using-deep-learning-toolbox

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!