The Python Oracle

ReduceLROnPlateau fallback to the previous weights with the minimum acc_loss

Become part of the top 3% of the developers by applying to Toptal https://topt.al/25cXVn

--

Music by Eric Matyas
https://www.soundimage.org
Track title: Puzzle Game 2 Looping

--

Chapters
00:00 Question
00:50 Accepted answer (Score 4)
01:47 Answer 2 (Score 1)
02:09 Thank you

--

Full question
https://stackoverflow.com/questions/5222...

Answer 1 links:
[ReduceLROnPlateau]: https://github.com/keras-team/keras/blob...

--

Content licensed under CC BY-SA
https://meta.stackexchange.com/help/lice...

--

Tags
#python #keras

#avk47



ACCEPTED ANSWER

Score 5


Here's a working example following @nuric's direction:

from tensorflow.python.keras.callbacks import ReduceLROnPlateau
from tensorflow.python.platform import tf_logging as logging

class ReduceLRBacktrack(ReduceLROnPlateau):
    def __init__(self, best_path, *args, **kwargs):
        super(ReduceLRBacktrack, self).__init__(*args, **kwargs)
        self.best_path = best_path

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get(self.monitor)
        if current is None:
            logging.warning('Reduce LR on plateau conditioned on metric `%s` '
                            'which is not available. Available metrics are: %s',
                            self.monitor, ','.join(list(logs.keys())))
        elif not self.monitor_op(current, self.best):  # not a new best
            if not self.in_cooldown():  # and we're not in cooldown
                if self.wait + 1 >= self.patience:  # about to reduce the LR
                    # restore the best model seen so far before the LR drops
                    print("Backtracking to best model before reducing LR")
                    self.model.load_weights(self.best_path)

        super().on_epoch_end(epoch, logs)  # actually reduce LR

The ModelCheckpoint callback can be used to keep the best-model dump up to date; for example, pass the following two callbacks to model.fit:

model_checkpoint_path = <path to checkpoint>
c1 = ModelCheckpoint(model_checkpoint_path, 
                     save_best_only=True,
                     monitor=...)
c2 = ReduceLRBacktrack(best_path=model_checkpoint_path, monitor=...)
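The interplay of patience, checkpointing, and backtracking can be sketched without TensorFlow. The `PlateauBacktracker` class below is a hypothetical, simplified stand-in for the callback pair's bookkeeping (no cooldown, loss-style "lower is better" monitoring, and an in-memory `best_state` playing the role of the checkpoint file):

```python
class PlateauBacktracker:
    """Toy model of ReduceLRBacktrack + ModelCheckpoint bookkeeping.

    Tracks the best monitored loss; after `patience` epochs without
    improvement it halves the LR and rolls back to the best state.
    """
    def __init__(self, lr=0.1, factor=0.5, patience=2):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.wait = 0
        self.best = float("inf")
        self.best_state = None  # plays the role of the checkpoint file

    def on_epoch_end(self, loss, state):
        if loss < self.best:            # new best: "checkpoint", reset patience
            self.best, self.best_state, self.wait = loss, state, 0
            return state
        self.wait += 1
        if self.wait >= self.patience:  # plateau: backtrack, then reduce LR
            self.lr *= self.factor
            self.wait = 0
            return self.best_state      # "load best weights"
        return state

cb = PlateauBacktracker(lr=0.1, factor=0.5, patience=2)
history = []
for epoch, loss in enumerate([1.0, 0.8, 0.9, 0.95, 0.7]):
    history.append(cb.on_epoch_end(loss, state=f"weights@epoch{epoch}"))
```

With this loss sequence, the two non-improving epochs (0.9, 0.95) exhaust the patience, so the LR is halved once and training resumes from the epoch-1 checkpoint rather than the plateaued weights.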



ANSWER 2

Score 1


You could create a custom callback inheriting from ReduceLROnPlateau, something along the lines of:

class CheckpointLR(ReduceLROnPlateau):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.last_weights = None  # snapshot of the previous epoch's weights

    # override on_epoch_end()
    def on_epoch_end(self, epoch, logs=None):
        if not self.in_cooldown():
            temp = self.model.get_weights()
            if self.last_weights is not None:
                self.model.set_weights(self.last_weights)
            self.last_weights = temp
        super().on_epoch_end(epoch, logs)  # actually reduce LR
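The core of this answer is the weight swap: the model is rolled back one epoch while its pre-swap weights become the new rollback snapshot. A tiny illustration of that exchange, with plain values standing in for `get_weights()`/`set_weights()` (the function name `swap_weights` is made up for this sketch):

```python
def swap_weights(model_weights, last_weights):
    """Mirror of the swap in CheckpointLR.on_epoch_end: returns
    (weights the model now carries, new value of self.last_weights)."""
    return last_weights, model_weights

# After epoch 2 the model holds w2 and the snapshot holds w1; the swap
# restores w1 to the model and keeps w2 around for the next epoch.
model_w, snapshot = swap_weights("w2", "w1")
```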