PyTorch. Snapshot Weight Averaging.

Okay, you have a number of checkpoints from a training loop that ran for 100500 epochs. Or you’ve carried out several experiments with one architecture, changing only global parameters, and now you have 7 saved .pth models in your folder.

But none of these models reaches 90% accuracy: each one is still 0.5–0.7% short.

So how do you reach the desired 90%? Combine all the models by averaging the weights of the snapshots.
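Written as a formula (the notation is introduced here for illustration and does not come from the original post): with K snapshots of the same architecture, and θ_k^(p) denoting the tensor stored under parameter name p in snapshot k, the averaged model takes the elementwise mean of each parameter:

```latex
\theta^{(p)}_{\mathrm{avg}} \;=\; \frac{1}{K} \sum_{k=1}^{K} \theta^{(p)}_{k}
\qquad \text{for every parameter name } p
```

Every snapshot must therefore have the same parameter names and tensor shapes, which is why this only makes sense for models that share one architecture.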

Explanation.

Get a dictionary (the state_dict) for each snapshot, mapping parameter names to their values.
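Below is a minimal sketch of reading the snapshots back from disk. The file names (snapshot_0.pth and so on) and the optional "state_dict" wrapper key are assumptions for the demo: adapt them to however your checkpoints were saved. It also assumes each .pth file holds a state_dict (or a dict containing one), not a whole pickled model.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn


def load_state_dict(path):
    """Load one snapshot and return its parameter-name -> tensor dictionary."""
    # map_location="cpu" lets GPU-saved checkpoints load on any machine.
    checkpoint = torch.load(path, map_location="cpu")
    # Assumption: a checkpoint is either a bare state_dict or a dict that
    # stores it under a "state_dict" key; change the key to match your files.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        return checkpoint["state_dict"]
    return checkpoint


# Demo: write three toy snapshots (hypothetical names) to a temporary folder,
# then read them back as a list of state_dicts.
with tempfile.TemporaryDirectory() as folder:
    for i in range(3):
        torch.manual_seed(i)
        torch.save(nn.Linear(4, 2).state_dict(), Path(folder) / f"snapshot_{i}.pth")

    paths = sorted(Path(folder).glob("*.pth"))
    state_dicts = [load_state_dict(p) for p in paths]

print(len(state_dicts))             # 3
print(list(state_dicts[0].keys()))  # ['weight', 'bias']
```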

Iterate over the parameters and store the average of each parameter's values across all snapshots in a new state_dict.

Load the new state_dict into the model.
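The three steps above can be sketched roughly as follows. This is an illustration rather than the author's original code: the small make_model network is a hypothetical stand-in for your architecture, and the snapshots are generated in memory with different random seeds instead of being loaded from .pth files. Integer buffers such as BatchNorm's num_batches_tracked are copied from the first snapshot rather than averaged, which is a design choice of this sketch.

```python
import torch
import torch.nn as nn


def average_state_dicts(state_dicts):
    """Return a new state_dict whose every entry is the mean over the snapshots."""
    averaged = {}
    for name in state_dicts[0]:
        # Assumes every snapshot has the same keys (same architecture).
        tensors = [sd[name] for sd in state_dicts]
        if tensors[0].is_floating_point():
            # Stack along a new dim 0, then take the elementwise mean.
            averaged[name] = torch.stack(tensors).mean(dim=0)
        else:
            # Integer buffers (e.g. BatchNorm's num_batches_tracked) are not
            # averaged; keep the first snapshot's value.
            averaged[name] = tensors[0].clone()
    return averaged


def make_model():
    # Hypothetical architecture shared by all snapshots.
    return nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))


# Step 1 stand-in: three "snapshots" with different random weights.
snapshots = []
for seed in range(3):
    torch.manual_seed(seed)
    snapshots.append(make_model().state_dict())

# Step 2: average every parameter into a new state_dict.
avg_sd = average_state_dicts(snapshots)

# Step 3: load it; strict=True (the default) requires every key to match.
model = make_model()
model.load_state_dict(avg_sd)
model.eval()
```

Evaluate the averaged model on your validation set before trusting it: a plain average is not guaranteed to beat the best single snapshot, and it tends to work best when the snapshots come from the same training run, since separately trained networks need not share a region of weight space where their average is still a good model. If the network has BatchNorm layers, PyTorch's torch.optim.swa_utils.update_bn can recompute the running statistics for the averaged weights with one pass over the training data.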

Computer Vision Engineer
