Che cos’è l’overfitting?
L’overfitting è un comportamento del Machine Learning che si verifica quando il modello è così strettamente allineato ai dati di addestramento che non sa come rispondere ai nuovi dati. L’overfitting può verificarsi quando:
- Il modello di Machine Learning è troppo complesso; memorizza schemi molto sottili nei dati di addestramento che non generalizzano bene.
- La dimensione dei dati di addestramento è troppo piccola per la complessità del modello e/o contiene grandi quantità di informazioni irrilevanti.
È possibile evitare l’overfitting gestendo la complessità del modello e migliorando il set di dati di addestramento.
Overfitting vs. Underfitting
L’underfitting è il concetto opposto all’overfitting: il modello non si allinea bene ai dati di addestramento o non generalizza bene rispetto ai nuovi dati. L’overfitting e l’underfitting possono essere presenti sia nei modelli di classificazione che in quelli di regressione. La figura seguente illustra come il confine della decisione di classificazione e la linea di regressione seguano i dati di addestramento troppo da vicino per un modello sottoposto a overfitting e non abbastanza da vicino per un modello sottoposto a underfitting.
Se si considera solo l’errore calcolato di un modello di Machine Learning per i dati di addestramento, l’overfitting è più difficile da rilevare rispetto all’underfitting. Quindi, per evitare l’overfitting, è importante convalidare un modello di Machine Learning prima di utilizzarlo su dati di prova.
Errore |
Overfitting |
Fit esatto |
Underfitting |
Formazione |
Bassa |
Bassa |
Alta |
Test |
Alta |
Bassa |
Alta |
Utilizzando MATLAB® con Statistics and Machine Learning Toolbox™ e Deep Learning Toolbox™, è possibile prevenire l’overfitting dei modelli di Machine Learning e di Deep Learning. MATLAB fornisce funzioni e metodi specificamente progettati per evitare l’overfitting dei modelli. È possibile utilizzare questi strumenti durante l’addestramento o la messa a punto del modello per proteggerlo dall’overfitting.
Come evitare l’overfitting riducendo la complessità del modello
Con MATLAB è possibile addestrare modelli di Machine Learning e modelli di Deep Learning (come le CNN) da zero o sfruttare modelli di Deep Learning preaddestrati. Per evitare l’overfitting, esegui la convalida del modello per assicurarsi di scegliere un modello con il giusto livello di complessità per i dati o utilizza la regolarizzazione per ridurre la complessità del modello.
Convalida dei modelli
L’errore di un modello sottoposto a overfitting è basso se calcolato per i dati di addestramento. È buona norma convalidare il modello su un set di dati separato (cioè un set di dati di convalida) prima di introdurre nuovi dati. Per i modelli di Machine Learning di MATLAB, è possibile utilizzare la funzione cvpartition
per partizionare in modo casuale un set di dati in set di addestramento e di convalida. Per i modelli di modelli di Deep Learning, è possibile monitorare l’accuratezza della convalida durante l’addestramento. Il miglioramento della misura di accuratezza correttamente convalidata per i tuoi modelli attraverso la selezione del modello e la messa a punto degli iperparametri dovrebbe tradursi in una migliore accuratezza quando il modello vede nuovi dati.
La convalida incrociata è una tecnica di valutazione del modello utilizzata per valutare le prestazioni di un algoritmo di Machine Learning nel fare previsioni su insiemi di dati su cui non è stato addestrato. La convalida incrociata aiuta a scegliere un algoritmo non troppo complesso che possa causare overfitting. Utilizza la funzione crossval
per calcolare la stima dell’errore di convalida incrociata per i modelli di Machine Learning utilizzando le tecniche più diffuse di convalida incrociata, come k-fold (consente di partizionare i dati in sottoinsiemi k di dimensioni approssimativamente uguali scelti in modo casuale) e holdout (consente di partizionare in modo casuale i dati in esattamente due sottoinsiemi di una dimensione specificata).
Regolarizzazione
La regolarizzazione è una tecnica utilizzata per prevenire l’overfitting statistico in un modello di Machine Learning. Gli algoritmi di regolarizzazione prevedono solitamente l’applicazione di una penalità per la complessità o per l’irregolarità. Introducendo ulteriori informazioni nel modello, gli algoritmi di regolarizzazione possono gestire la multicollinearità e i predittori ridondanti rendendo il modello più mirato e accurato.
Per il Machine Learning, è possibile scegliere tra tre tecniche di regolarizzazione diffuse: “lasso” (norma L1), “ridge” (norma L2) ed “elastic net”, con diversi tipi di modelli lineari di Machine Learning. Per il Deep Learning, è possibile incrementare il fattore di regolarizzazione L2 nelle opzioni di addestramento specificate o utilizzare layer di dropout nella rete per evitare l’overfitting.
Esempi e consigli pratici
Come evitare l’overfitting migliorando il set di dati di addestramento
La convalida incrociata e la regolarizzazione impediscono l’overfitting gestendo la complessità del modello. Un altro approccio consiste nel migliorare il set di dati. I modelli di Deep Learning, in particolare, richiedono grandi quantità di dati per evitare l’overfitting.
Incremento dei dati
Quando la disponibilità di dati è limitata, l’incremento dei dati è un metodo per espandere artificialmente i punti dati del set di dati di addestramento aggiungendo versioni randomizzate dei dati esistenti al set. MATLAB consente di incrementare vari tipi di dati, tra cui dati immagine e audio. Ad esempio, è possibile incrementare i dati immagine randomizzando la scala e la rotazione delle immagini esistenti.
Generazione di dati
La generazione di dati sintetici è un altro metodo per espandere un set di dati. Con MATLAB è possibile generare dati sintetici utilizzando le reti generative avversarie (GAN) o i gemelli digitali (generazione di dati tramite la simulazione).
Pulizia dei dati
La rumorosità dei dati contribuisce all’overfitting. Un approccio comune per ridurre i punti dati indesiderati consiste nella rimozione degli outlier dai dati mediante la funzione rmoutliers
.