import collections as col
import math

from torch import optim
from torch.optim import lr_scheduler

from .utils import SCHEDULER_REGISTRY


def _get_position_from_periods(iteration: int, cumulative_period: list[int]) -> int:
    """
    Get position from a period list.

    Specifically, it returns the index of the right-closest number in the period list.
    For example, suppose ``cumulative_period`` is ``[100, 200, 300, 400]``. Then:

    * If ``iteration == 50``, return 0
    * If ``iteration == 210``, return 2
    * If ``iteration == 300``, return 2

    Args:
        iteration: current iteration.
        cumulative_period: cumulative period list.

    Returns:
        The position of the right-closest number in the period list.
    """
    for i, period in enumerate(cumulative_period):
        if iteration <= period:
            return i
    return 0
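

# Illustration only (not part of the original module): a quick sanity check that
# mirrors the examples in the ``_get_position_from_periods`` docstring. The helper
# name ``_demo_get_position_from_periods`` is an assumption made for this sketch.
def _demo_get_position_from_periods() -> None:
    cumulative = [100, 200, 300, 400]
    assert _get_position_from_periods(50, cumulative) == 0
    assert _get_position_from_periods(210, cumulative) == 2
    assert _get_position_from_periods(300, cumulative) == 2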


@SCHEDULER_REGISTRY.register
class CosineAnnealingRestartLR(lr_scheduler.LRScheduler):
    """
    A cosine annealing with restarts LR scheduler.

    Example:
        Given

        .. code-block:: text

            periods = [10, 10, 10, 10]
            restart_weights = [1, 0.5, 0.5, 0.5]
            eta_min = 1e-7

        the scheduler will have 4 cycles of 10 iterations each. At the 10th, 20th,
        and 30th iterations, the scheduler will restart with the corresponding
        weights in ``restart_weights``.

    Args:
        optimizer: the optimizer.
        periods: period for each cosine annealing cycle.
        restart_weights: (optional) restart weights at each restart iteration.
        eta_min: the minimum lr. Defaults to 0.
        last_epoch: used in ``_LRScheduler``. Defaults to -1.
    """

    def __init__(
        self,
        optimizer: optim.Optimizer,
        periods: list[int],
        restart_weights: list[float] | None = None,
        eta_min: float = 0,
        last_epoch: int = -1,
    ):
        """Create the scheduler."""
        if restart_weights is None:
            restart_weights = [1]
        self._periods = periods
        self._restart_weights = restart_weights
        self._eta_min = eta_min
        assert len(self._periods) == len(
            self._restart_weights
        ), "periods and restart_weights should have the same length."
        self._cumulative_period = [
            sum(self._periods[0 : i + 1]) for i in range(len(self._periods))
        ]
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the current learning rate."""
        idx = _get_position_from_periods(self.last_epoch, self._cumulative_period)
        current_weight = self._restart_weights[idx]
        nearest_restart = 0 if idx == 0 else self._cumulative_period[idx - 1]
        current_period = self._periods[idx]

        return [
            self._eta_min
            + current_weight
            * 0.5
            * (base_lr - self._eta_min)
            * (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
            for base_lr in self.base_lrs
        ]
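

# Usage sketch (illustration only, not part of the original module): drive
# ``CosineAnnealingRestartLR`` with a throwaway SGD optimizer using the settings
# from the class docstring. The helper name ``_demo_cosine_annealing_restart_lr``
# and the dummy parameter are assumptions made for this example.
def _demo_cosine_annealing_restart_lr() -> list[float]:
    import torch

    params = [torch.zeros(1, requires_grad=True)]
    optimizer = optim.SGD(params, lr=1e-3)
    scheduler = CosineAnnealingRestartLR(
        optimizer,
        periods=[10, 10, 10, 10],
        restart_weights=[1, 0.5, 0.5, 0.5],
        eta_min=1e-7,
    )
    lrs = []
    for _ in range(40):
        optimizer.step()
        scheduler.step()
        lrs.append(scheduler.get_last_lr()[0])
    # The learning rate decays from 1e-3 toward eta_min within each 10-iteration
    # cycle, then restarts near 1e-3 * restart_weights[cycle].
    return lrs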


@SCHEDULER_REGISTRY.register
class MultiStepRestartLR(lr_scheduler.LRScheduler):
    """
    Multi-step with restarts LR scheduler.

    Args:
        optimizer: torch optimizer.
        milestones: iterations at which the learning rate is decreased.
        gamma: decrease ratio. Defaults to 0.1.
        restarts: (optional) restart iterations.
        restart_weights: (optional) restart weights at each restart iteration.
        last_epoch: used in ``_LRScheduler``. Defaults to -1.
    """

    def __init__(
        self,
        optimizer: optim.Optimizer,
        milestones: list[int],
        gamma: float = 0.1,
        restarts: list[int] | None = None,
        restart_weights: list[float] | None = None,
        last_epoch: int = -1,
    ):
        """Create the scheduler."""
        if restarts is None:
            restarts = [0]
        if restart_weights is None:
            restart_weights = [1]
        self._milestones = col.Counter(milestones)
        self._gamma = gamma
        self._restarts = restarts
        self._restart_weights = restart_weights
        assert len(self._restarts) == len(
            self._restart_weights
        ), "restarts and their weights do not match."
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the current learning rate."""
        if self.last_epoch in self._restarts:
            weight = self._restart_weights[self._restarts.index(self.last_epoch)]
            return [group["initial_lr"] * weight for group in self.optimizer.param_groups]
        if self.last_epoch not in self._milestones:
            return [group["lr"] for group in self.optimizer.param_groups]
        return [
            group["lr"] * self._gamma ** self._milestones[self.last_epoch]
            for group in self.optimizer.param_groups
        ]
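

# Usage sketch (illustration only, not part of the original module): step
# ``MultiStepRestartLR`` so that the LR is scaled by ``gamma`` at each milestone
# and reset to ``initial_lr * restart_weight`` at each restart iteration. The
# helper name ``_demo_multi_step_restart_lr`` and all values are assumptions
# made for this example.
def _demo_multi_step_restart_lr() -> list[float]:
    import torch

    params = [torch.zeros(1, requires_grad=True)]
    optimizer = optim.SGD(params, lr=1e-1)
    scheduler = MultiStepRestartLR(
        optimizer,
        milestones=[5, 8],
        gamma=0.1,
        restarts=[10],
        restart_weights=[0.5],
    )
    lrs = []
    for _ in range(12):
        optimizer.step()
        scheduler.step()
        lrs.append(scheduler.get_last_lr()[0])
    # lr: 1e-1 for iterations 1-4, 1e-2 for 5-7, 1e-3 for 8-9,
    # then restarts at 5e-2 from iteration 10 onward.
    return lrs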