"""Record the learning rate produced by a scheduler at every training step.

Builds a throwaway ``torch.nn.Linear(1, 1)`` model and an AdamW optimizer,
attaches the scheduler returned by ``models.lr_schedulers.get_scheduler``,
advances it ``MAX_TRAINING_STEPS`` times while collecting
``get_last_lr()[0]`` before each advance, and prints the learning rate
recorded at index ``WARMUP_STEPS - 1``. That index is presumably the final
warmup step; this assumes ``get_scheduler`` warms up over
``num_warmup_steps`` steps, so confirm it against ``models.lr_schedulers``.
"""

import torch
from torch.optim import AdamW

from models.lr_schedulers import get_scheduler

MAX_TRAINING_STEPS = 100
WARMUP_STEPS = 80
INITIAL_LR = 5e-5
SCHEDULER_TYPE = "cosine"  # "linear", "cosine"
# ---------------------------------------------

dummy_model = torch.nn.Linear(1, 1)
dummy_optimizer = AdamW(dummy_model.parameters(), lr=INITIAL_LR)
lr_scheduler = get_scheduler(
    name=SCHEDULER_TYPE,
    optimizer=dummy_optimizer,
    num_warmup_steps=WARMUP_STEPS,
    num_training_steps=MAX_TRAINING_STEPS,
)

all_lrs = []
for _ in range(MAX_TRAINING_STEPS):
    # Record the LR in effect for this step before advancing the schedule.
    all_lrs.append(lr_scheduler.get_last_lr()[0])
    # PyTorch expects optimizer.step() before lr_scheduler.step() and warns
    # otherwise. No parameter has a gradient here, so AdamW skips every
    # parameter and the model is left unchanged.
    dummy_optimizer.step()
    lr_scheduler.step()

# Derive the index from WARMUP_STEPS instead of hard-coding 79, so the probe
# stays correct when the warmup length is changed above.
print(all_lrs[WARMUP_STEPS - 1])