1. Learning Rate Scheduling Strategies

Q: What are the main attributes and methods of _LRScheduler?

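  • Attributes:
      • optimizer: the optimizer the scheduler is attached to
      • last_epoch: the epoch counter, incremented by each call to step()
      • base_lrs: the initial learning rate of each parameter group
  • Methods:
      • step(): update the learning rate for the next epoch
      • get_lr(): virtual method, overridden by each concrete scheduler, that computes the learning rate for the next epoch
      • get_last_lr(): (newer PyTorch versions) return the learning rates most recently set by the scheduler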
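As a rough mental model of how these pieces fit together, the pure-Python sketch below imitates the protocol: step() increments last_epoch, asks get_lr() for new values and writes them into the parameter groups, while base_lrs keeps the initial values. ToyScheduler, ToyHalvingLR and the plain-dict param_groups are hypothetical stand-ins, not PyTorch code; the real class also handles resuming from a given last_epoch, state_dict() and more.

```python
# Hypothetical, simplified model of the _LRScheduler protocol (not PyTorch's source).


class ToyScheduler:
    def __init__(self, param_groups, last_epoch=-1):
        # param_groups stands in for optimizer.param_groups (dicts with an "lr" key)
        self.param_groups = param_groups
        # base_lrs: the initial learning rate of every parameter group
        self.base_lrs = [group["lr"] for group in param_groups]
        self.last_epoch = last_epoch
        # as in PyTorch, construction performs an initial step: last_epoch -1 -> 0
        self.step()

    def get_lr(self):
        # "virtual" method: every concrete schedule overrides it
        raise NotImplementedError

    def step(self):
        # advance the epoch counter, then write the new lr into each group
        self.last_epoch += 1
        for group, lr in zip(self.param_groups, self.get_lr()):
            group["lr"] = lr


class ToyHalvingLR(ToyScheduler):
    # hypothetical schedule: halve the initial lr once per epoch
    def get_lr(self):
        return [base_lr * 0.5 ** self.last_epoch for base_lr in self.base_lrs]


groups = [{"lr": 0.1}]
scheduler = ToyHalvingLR(groups)
for _ in range(3):
    scheduler.step()
print(scheduler.last_epoch, groups[0]["lr"], scheduler.base_lrs)
```

Note that the current lr lives in the parameter groups, while base_lrs is never modified; a closed-form get_lr() like this one always computes from base_lrs.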

Q: How do I decay the learning rate at a fixed interval?

  • torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)
  • Update rule: every step_size epochs, lr = lr * gamma
  • step_size: number of epochs between two adjustments
  • gamma: multiplicative decay factor
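Because the rule fires once every step_size epochs, the lr at any epoch can be written in closed form as base_lr * gamma ** (epoch // step_size), where base_lr is the initial lr. The plain-Python sketch below (step_lr is a hypothetical helper, not part of PyTorch) uses the settings of the StepLR example that follows: LR = 0.1, step_size = 50, gamma = 0.1.

```python
def step_lr(base_lr, epoch, step_size, gamma=0.1):
    # one factor of gamma per completed block of step_size epochs
    return base_lr * gamma ** (epoch // step_size)


# the lr stays flat inside each block and drops at epochs 50, 100, 150
for epoch in (0, 49, 50, 99, 100, 199):
    print(epoch, step_lr(0.1, epoch, step_size=50))
```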

Q: StepLR code example

In [2]:
import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(1)

LR = 0.1
iteration = 10
max_epoch = 200
# ------------------------------ fake data and optimizer  ------------------------------

weights = torch.randn((1), requires_grad=True)
target = torch.zeros((1))

optimizer = optim.SGD([weights], lr=LR, momentum=0.9)

# ------------------------------ 1 Step LR ------------------------------

scheduler_lr = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)  # set the lr decay policy

lr_list, epoch_list = list(), list()
for epoch in range(max_epoch):

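    # lr in use for this epoch: get_last_lr() in newer PyTorch versions; there, calling get_lr()
    # outside step() emits a UserWarning and can return a different value
    lr_list.append(scheduler_lr.get_last_lr()[0])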
    epoch_list.append(epoch)

    for i in range(iteration):

        loss = torch.pow((weights - target), 2)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

    scheduler_lr.step()

plt.plot(epoch_list, lr_list, label="Step LR Scheduler")
plt.xlabel("Epoch")
plt.ylabel("Learning rate")
plt.legend()
plt.show()

Q: How do I decay the learning rate at specified epochs?

  • torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1)
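  • milestones: increasing list of epoch indices at which the lr is adjusted
  • gamma: multiplicative decay factor; at each milestone, lr = lr * gamma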
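MultiStepLR behaves like StepLR with irregular intervals: by a given epoch, the lr has been multiplied by gamma once for every milestone already reached. A minimal pure-Python sketch of that closed form, assuming milestones is sorted in increasing order; multistep_lr is a hypothetical helper, not a PyTorch function, and it uses the milestones [50, 125, 160] of the example that follows.

```python
from bisect import bisect_right


def multistep_lr(base_lr, epoch, milestones, gamma=0.1):
    # number of milestones already reached (milestone m counts once epoch >= m)
    passed = bisect_right(milestones, epoch)
    return base_lr * gamma ** passed


milestones = [50, 125, 160]
for epoch in (0, 49, 50, 124, 125, 160, 199):
    print(epoch, multistep_lr(0.1, epoch, milestones))
```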

Q: MultiStepLR code example

In [3]:
import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(1)

LR = 0.1
iteration = 10
max_epoch = 200
# ------------------------------ fake data and optimizer  ------------------------------

weights = torch.randn((1), requires_grad=True)
target = torch.zeros((1))

optimizer = optim.SGD([weights], lr=LR, momentum=0.9)

# ------------------------------ 2 Multi Step LR ------------------------------
milestones = [50, 125, 160]
scheduler_lr = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

lr_list, epoch_list = list(), list()
for epoch in range(max_epoch):

    lr_list.append(scheduler_lr.get_last_lr()[0])  # lr in use for this epoch
    epoch_list.append(epoch)

    for i in range(iteration):

        loss = torch.pow((weights - target), 2)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

    scheduler_lr.step()

plt.plot(epoch_list, lr_list, label="Multi Step LR Scheduler\nmilestones:{}".format(milestones))
plt.xlabel("Epoch")
plt.ylabel("Learning rate")
plt.legend()
plt.show()

Q: How do I decay the learning rate exponentially?

  • torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1)
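  • gamma: base of the exponent, usually a number close to 1 such as 0.95
  • Update rule: lr = base_lr * gamma ** epoch, i.e. the lr is multiplied by gamma after every epoch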
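The recursive form (multiply the current lr by gamma every epoch) and the closed form base_lr * gamma ** epoch describe the same schedule. A small pure-Python check, assuming base_lr = 0.1 and gamma = 0.95 as in the example that follows; exponential_lr is a hypothetical helper.

```python
import math


def exponential_lr(base_lr, epoch, gamma):
    # closed form: one factor of gamma per elapsed epoch
    return base_lr * gamma ** epoch


# recursive form: multiply the current lr by gamma after every epoch
lr = 0.1
recursive = []
for epoch in range(200):
    recursive.append(lr)
    lr *= 0.95

closed = [exponential_lr(0.1, epoch, 0.95) for epoch in range(200)]
print(all(math.isclose(a, b, rel_tol=1e-9) for a, b in zip(recursive, closed)))
print(closed[-1])  # lr at epoch 199, roughly 3.7e-06
```

Even with gamma = 0.95 the lr shrinks by more than four orders of magnitude over 200 epochs, which is why gamma is usually kept close to 1.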

Q: ExponentialLR code example

In [4]:
import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(1)

LR = 0.1
iteration = 10
max_epoch = 200
# ------------------------------ fake data and optimizer  ------------------------------

weights = torch.randn((1), requires_grad=True)
target = torch.zeros((1))

optimizer = optim.SGD([weights], lr=LR, momentum=0.9)

# ------------------------------ 3 Exponential LR ------------------------------
gamma = 0.95
scheduler_lr = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

lr_list, epoch_list = list(), list()
for epoch in range(max_epoch):

    lr_list.append(scheduler_lr.get_last_lr()[0])  # lr in use for this epoch
    epoch_list.append(epoch)

    for i in range(iteration):

        loss = torch.pow((weights - target), 2)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

    scheduler_lr.step()

plt.plot(epoch_list, lr_list, label="Exponential LR Scheduler\ngamma:{}".format(gamma))
plt.xlabel("Epoch")
plt.ylabel("Learning rate")
plt.legend()
plt.show()

Q: How do I adjust the learning rate following a cosine schedule?

  • torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=0, last_epoch=-1)
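  • T_max: the decay period, i.e. the number of epochs over which the lr falls from its initial value to eta_min
  • eta_min: lower bound of the learning rate
  • Update rule, where η_max is the initial learning rate and T_cur the number of epochs since the start: $$\eta_{t}=\eta_{\min }+\frac{1}{2}\left(\eta_{\max }-\eta_{\min }\right)\left(1+\cos \left(\frac{T_{cur}}{T_{\max }} \pi\right)\right)$$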
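Evaluating the formula directly, with η_max as the initial lr and T_cur as the epoch index, gives the sketch below; cosine_annealing_lr is a hypothetical helper, not a PyTorch function. Because the cosine is periodic, past T_max the formula climbs back toward η_max with period 2·T_max, so with max_epoch = 200 and T_max = 50 the example that follows should trace several rise-and-fall cycles.

```python
import math


def cosine_annealing_lr(eta_max, epoch, t_max, eta_min=0.0):
    # eta_t = eta_min + 0.5 * (eta_max - eta_min) * (1 + cos(pi * T_cur / T_max)), T_cur = epoch
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


# 0.1 at epoch 0, half-way at epoch 25, eta_min at epoch 50, back to 0.1 at epoch 100
for epoch in (0, 25, 50, 75, 100):
    print(epoch, round(cosine_annealing_lr(0.1, epoch, t_max=50), 6))
```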

Q: CosineAnnealingLR code example

In [5]:
import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(1)

LR = 0.1
iteration = 10
max_epoch = 200
# ------------------------------ fake data and optimizer  ------------------------------

weights = torch.randn((1), requires_grad=True)
target = torch.zeros((1))

optimizer = optim.SGD([weights], lr=LR, momentum=0.9)

# ------------------------------ 4 Cosine Annealing LR ------------------------------
t_max = 50
scheduler_lr = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max, eta_min=0.)

lr_list, epoch_list = list(), list()
for epoch in range(max_epoch):

    lr_list.append(scheduler_lr.get_last_lr()[0])  # lr in use for this epoch
    epoch_list.append(epoch)

    for i in range(iteration):

        loss = torch.pow((weights - target), 2)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

    scheduler_lr.step()

plt.plot(epoch_list, lr_list, label="CosineAnnealingLR Scheduler\nT_max:{}".format(t_max))
plt.xlabel("Epoch")
plt.ylabel("Learning rate")
plt.legend()
plt.show()

Q: How do I monitor a metric and reduce the learning rate when it stops improving?

  • torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)
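  • mode: 'min' or 'max'; in 'min' mode the lr is reduced when the monitored metric stops decreasing, in 'max' mode when it stops increasing
  • factor: multiplicative factor, new_lr = lr * factor
  • patience: number of epochs without improvement to tolerate before reducing the lr
  • cooldown: number of epochs after a reduction during which monitoring is paused
  • verbose: whether to print a message at each reduction
  • min_lr: lower bound of the learning rate
  • eps: minimal decay; if the difference between the old and new lr is smaller than eps, the update is skipped
  • threshold / threshold_mode: how much the metric must change to count as an improvement; with 'rel' in 'min' mode it must fall below best * (1 - threshold)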
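To make patience and cooldown concrete, here is a simplified pure-Python model of the decision logic for mode='min' with threshold_mode='rel'. PlateauSketch is hypothetical, written to mirror the parameters listed above; it is not PyTorch's implementation. Fed the settings and loss sequence of the example that follows, it reduces the lr at epochs 16, 37 and 58, matching the log printed there: the 11th epoch without improvement triggers a reduction, cooldown then ignores 10 epochs, and once min_lr is reached further reductions are smaller than eps and are skipped.

```python
class PlateauSketch:
    """Hypothetical, simplified model of ReduceLROnPlateau (mode='min', threshold_mode='rel')."""

    def __init__(self, lr, factor=0.1, patience=10, cooldown=0, min_lr=0.0,
                 threshold=1e-4, eps=1e-8):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.cooldown = cooldown
        self.min_lr = min_lr
        self.threshold = threshold
        self.eps = eps
        self.best = float("inf")   # best metric seen so far
        self.num_bad_epochs = 0    # consecutive epochs without improvement
        self.cooldown_counter = 0  # epochs left in the current cooldown

    def step(self, metric):
        """Feed one epoch's metric; return True if the lr was reduced."""
        # only a drop below best * (1 - threshold) counts as an improvement
        if metric < self.best * (1 - self.threshold):
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        # while cooling down, epochs without improvement are not counted
        if self.cooldown_counter > 0:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0

        if self.num_bad_epochs > self.patience:
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0
            new_lr = max(self.lr * self.factor, self.min_lr)
            # skip reductions smaller than eps (e.g. once min_lr is reached)
            if self.lr - new_lr > self.eps:
                self.lr = new_lr
                return True
        return False


# same settings and loss sequence as the ReduceLROnPlateau example
sched = PlateauSketch(lr=0.1, factor=0.1, patience=10, cooldown=10, min_lr=1e-4)
loss_value = 0.5
reduced_at = []
for epoch in range(200):
    if epoch == 5:
        loss_value = 0.4
    if sched.step(loss_value):
        reduced_at.append(epoch)
print(reduced_at, sched.lr)
```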

Q: ReduceLROnPlateau code example

In [6]:
import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(1)

LR = 0.1
iteration = 10
max_epoch = 200
# ------------------------------ fake data and optimizer  ------------------------------

weights = torch.randn((1), requires_grad=True)
target = torch.zeros((1))

optimizer = optim.SGD([weights], lr=LR, momentum=0.9)

# ------------------------------ 5 Reduce LR On Plateau ------------------------------
loss_value = 0.5
accuracy = 0.9  # unused here; a metric like accuracy would be monitored with mode="max"

factor = 0.1
mode = "min"
patience = 10
cooldown = 10
min_lr = 1e-4
verbose = True  # print a message at each reduction (deprecated in recent PyTorch releases)

scheduler_lr = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=factor, mode=mode, patience=patience,
                                                    cooldown=cooldown, min_lr=min_lr, verbose=verbose)

for epoch in range(max_epoch):
    for i in range(iteration):

        # train(...)

        optimizer.step()
        optimizer.zero_grad()

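    # simulated metric: the loss improves once at epoch 5, then plateaus
    if epoch == 5:
        loss_value = 0.4

    scheduler_lr.step(loss_value)  # ReduceLROnPlateau.step() takes the monitored metric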
Epoch    16: reducing learning rate of group 0 to 1.0000e-02.
Epoch    37: reducing learning rate of group 0 to 1.0000e-03.
Epoch    58: reducing learning rate of group 0 to 1.0000e-04.

Q: How do I define a custom learning rate schedule?

  • torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)
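  • lr_lambda: a function, or a list of functions with one per parameter group; each takes the epoch index and returns a multiplicative factor, so lr = base_lr * lr_lambda(epoch)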
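Since LambdaLR multiplies the *initial* lr by the factor (lr = base_lr * lr_lambda(epoch)), its values can be checked by hand. The sketch below evaluates the two factor functions of the example that follows; lambda_lr is a hypothetical helper. It also explains an off-by-one in that example's printout: the loop prints after scheduler.step(), so the line labelled "epoch: 0" already shows the lr for epoch 1, and the first drop of lambda1 appears on the line labelled "epoch: 19".

```python
def lambda_lr(base_lr, epoch, lr_lambda):
    # the factor multiplies the initial lr, not the current one
    return base_lr * lr_lambda(epoch)


def lambda1(epoch):
    # step-style factor: divide by 10 every 20 epochs
    return 0.1 ** (epoch // 20)


def lambda2(epoch):
    # exponential-style factor
    return 0.95 ** epoch


# after the first scheduler.step(), last_epoch is 1
after_1 = [lambda_lr(0.1, 1, f) for f in (lambda1, lambda2)]
# after 20 calls, last_epoch is 20 and lambda1 drops by a factor of 10
after_20 = [lambda_lr(0.1, 20, f) for f in (lambda1, lambda2)]
print(after_1, after_20)
```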

Q: LambdaLR code example

In [7]:
import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(1)

LR = 0.1
iteration = 10
max_epoch = 200
# ------------------------------ fake data and optimizer  ------------------------------

weights = torch.randn((1), requires_grad=True)
target = torch.zeros((1))

optimizer = optim.SGD([weights], lr=LR, momentum=0.9)

# ------------------------------ 6 lambda ------------------------------
lr_init = 0.1

weights_1 = torch.randn((6, 3, 5, 5))
weights_2 = torch.ones((5, 5))

optimizer = optim.SGD([
    {'params': [weights_1]},
    {'params': [weights_2]}], lr=lr_init)

lambda1 = lambda epoch: 0.1 ** (epoch // 20)
lambda2 = lambda epoch: 0.95 ** epoch

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])

lr_list, epoch_list = list(), list()
for epoch in range(max_epoch):
    for i in range(iteration):

        # train(...)

        optimizer.step()
        optimizer.zero_grad()

    scheduler.step()

    # recorded after step(), so each entry is the lr for the next epoch
    lr_list.append(scheduler.get_last_lr())
    epoch_list.append(epoch)

    print('epoch:{:5d}, lr:{}'.format(epoch, scheduler.get_last_lr()))

plt.plot(epoch_list, [i[0] for i in lr_list], label="lambda 1")
plt.plot(epoch_list, [i[1] for i in lr_list], label="lambda 2")
plt.xlabel("Epoch")
plt.ylabel("Learning Rate")
plt.title("LambdaLR")
plt.legend()
plt.show()
epoch:    0, lr:[0.1, 0.095]
epoch:    1, lr:[0.1, 0.09025]
epoch:    2, lr:[0.1, 0.0857375]
epoch:    3, lr:[0.1, 0.081450625]
epoch:    4, lr:[0.1, 0.07737809374999999]
epoch:    5, lr:[0.1, 0.07350918906249998]
epoch:    6, lr:[0.1, 0.06983372960937498]
epoch:    7, lr:[0.1, 0.06634204312890622]
epoch:    8, lr:[0.1, 0.0630249409724609]
epoch:    9, lr:[0.1, 0.05987369392383787]
epoch:   10, lr:[0.1, 0.05688000922764597]
epoch:   11, lr:[0.1, 0.05403600876626367]
epoch:   12, lr:[0.1, 0.051334208327950485]
epoch:   13, lr:[0.1, 0.04876749791155296]
epoch:   14, lr:[0.1, 0.046329123015975304]
epoch:   15, lr:[0.1, 0.04401266686517654]
epoch:   16, lr:[0.1, 0.04181203352191771]
epoch:   17, lr:[0.1, 0.039721431845821824]
epoch:   18, lr:[0.1, 0.03773536025353073]
epoch:   19, lr:[0.010000000000000002, 0.03584859224085419]
epoch:   20, lr:[0.010000000000000002, 0.03405616262881148]
epoch:   21, lr:[0.010000000000000002, 0.0323533544973709]
epoch:   22, lr:[0.010000000000000002, 0.03073568677250236]
epoch:   23, lr:[0.010000000000000002, 0.02919890243387724]
epoch:   24, lr:[0.010000000000000002, 0.027738957312183378]
epoch:   25, lr:[0.010000000000000002, 0.026352009446574204]
epoch:   26, lr:[0.010000000000000002, 0.025034408974245494]
epoch:   27, lr:[0.010000000000000002, 0.023782688525533217]
epoch:   28, lr:[0.010000000000000002, 0.022593554099256556]
epoch:   29, lr:[0.010000000000000002, 0.02146387639429373]
epoch:   30, lr:[0.010000000000000002, 0.02039068257457904]
epoch:   31, lr:[0.010000000000000002, 0.019371148445850087]
epoch:   32, lr:[0.010000000000000002, 0.018402591023557582]
epoch:   33, lr:[0.010000000000000002, 0.017482461472379703]
epoch:   34, lr:[0.010000000000000002, 0.016608338398760716]
epoch:   35, lr:[0.010000000000000002, 0.01577792147882268]
epoch:   36, lr:[0.010000000000000002, 0.014989025404881546]
epoch:   37, lr:[0.010000000000000002, 0.014239574134637467]
epoch:   38, lr:[0.010000000000000002, 0.013527595427905593]
epoch:   39, lr:[0.0010000000000000002, 0.012851215656510312]
epoch:   40, lr:[0.0010000000000000002, 0.012208654873684797]
epoch:   41, lr:[0.0010000000000000002, 0.011598222130000557]
epoch:   42, lr:[0.0010000000000000002, 0.011018311023500529]
epoch:   43, lr:[0.0010000000000000002, 0.010467395472325502]
epoch:   44, lr:[0.0010000000000000002, 0.009944025698709225]
epoch:   45, lr:[0.0010000000000000002, 0.009446824413773765]
epoch:   46, lr:[0.0010000000000000002, 0.008974483193085076]
epoch:   47, lr:[0.0010000000000000002, 0.00852575903343082]
epoch:   48, lr:[0.0010000000000000002, 0.00809947108175928]
epoch:   49, lr:[0.0010000000000000002, 0.007694497527671315]
epoch:   50, lr:[0.0010000000000000002, 0.007309772651287749]
epoch:   51, lr:[0.0010000000000000002, 0.006944284018723362]
epoch:   52, lr:[0.0010000000000000002, 0.0065970698177871935]
epoch:   53, lr:[0.0010000000000000002, 0.006267216326897833]
epoch:   54, lr:[0.0010000000000000002, 0.005953855510552941]
epoch:   55, lr:[0.0010000000000000002, 0.005656162735025293]
epoch:   56, lr:[0.0010000000000000002, 0.005373354598274029]
epoch:   57, lr:[0.0010000000000000002, 0.005104686868360327]
epoch:   58, lr:[0.0010000000000000002, 0.004849452524942311]
epoch:   59, lr:[0.00010000000000000003, 0.004606979898695194]
epoch:   60, lr:[0.00010000000000000003, 0.004376630903760435]
epoch:   61, lr:[0.00010000000000000003, 0.004157799358572413]
epoch:   62, lr:[0.00010000000000000003, 0.003949909390643792]
epoch:   63, lr:[0.00010000000000000003, 0.003752413921111602]
epoch:   64, lr:[0.00010000000000000003, 0.003564793225056022]
epoch:   65, lr:[0.00010000000000000003, 0.0033865535638032207]
epoch:   66, lr:[0.00010000000000000003, 0.0032172258856130592]
epoch:   67, lr:[0.00010000000000000003, 0.0030563645913324064]
epoch:   68, lr:[0.00010000000000000003, 0.002903546361765786]
epoch:   69, lr:[0.00010000000000000003, 0.0027583690436774966]
epoch:   70, lr:[0.00010000000000000003, 0.0026204505914936217]
epoch:   71, lr:[0.00010000000000000003, 0.0024894280619189406]
epoch:   72, lr:[0.00010000000000000003, 0.0023649566588229936]
epoch:   73, lr:[0.00010000000000000003, 0.0022467088258818434]
epoch:   74, lr:[0.00010000000000000003, 0.002134373384587751]
epoch:   75, lr:[0.00010000000000000003, 0.0020276547153583635]
epoch:   76, lr:[0.00010000000000000003, 0.0019262719795904452]
epoch:   77, lr:[0.00010000000000000003, 0.001829958380610923]
epoch:   78, lr:[0.00010000000000000003, 0.0017384604615803768]
epoch:   79, lr:[1.0000000000000003e-05, 0.001651537438501358]
epoch:   80, lr:[1.0000000000000003e-05, 0.00156896056657629]
epoch:   81, lr:[1.0000000000000003e-05, 0.0014905125382474755]
epoch:   82, lr:[1.0000000000000003e-05, 0.0014159869113351015]
epoch:   83, lr:[1.0000000000000003e-05, 0.0013451875657683465]
epoch:   84, lr:[1.0000000000000003e-05, 0.001277928187479929]
epoch:   85, lr:[1.0000000000000003e-05, 0.0012140317781059325]
epoch:   86, lr:[1.0000000000000003e-05, 0.0011533301892006358]
epoch:   87, lr:[1.0000000000000003e-05, 0.001095663679740604]
epoch:   88, lr:[1.0000000000000003e-05, 0.0010408804957535737]
epoch:   89, lr:[1.0000000000000003e-05, 0.000988836470965895]
epoch:   90, lr:[1.0000000000000003e-05, 0.0009393946474176001]
epoch:   91, lr:[1.0000000000000003e-05, 0.0008924249150467202]
epoch:   92, lr:[1.0000000000000003e-05, 0.0008478036692943841]
epoch:   93, lr:[1.0000000000000003e-05, 0.0008054134858296649]
epoch:   94, lr:[1.0000000000000003e-05, 0.0007651428115381816]
epoch:   95, lr:[1.0000000000000003e-05, 0.0007268856709612725]
epoch:   96, lr:[1.0000000000000003e-05, 0.0006905413874132089]
epoch:   97, lr:[1.0000000000000003e-05, 0.0006560143180425484]
epoch:   98, lr:[1.0000000000000003e-05, 0.0006232136021404209]
epoch:   99, lr:[1.0000000000000004e-06, 0.0005920529220333997]
epoch:  100, lr:[1.0000000000000004e-06, 0.0005624502759317298]
epoch:  101, lr:[1.0000000000000004e-06, 0.0005343277621351433]
epoch:  102, lr:[1.0000000000000004e-06, 0.0005076113740283861]
epoch:  103, lr:[1.0000000000000004e-06, 0.00048223080532696673]
epoch:  104, lr:[1.0000000000000004e-06, 0.0004581192650606184]
epoch:  105, lr:[1.0000000000000004e-06, 0.00043521330180758743]
epoch:  106, lr:[1.0000000000000004e-06, 0.00041345263671720806]
epoch:  107, lr:[1.0000000000000004e-06, 0.0003927800048813476]
epoch:  108, lr:[1.0000000000000004e-06, 0.00037314100463728026]
epoch:  109, lr:[1.0000000000000004e-06, 0.00035448395440541624]
epoch:  110, lr:[1.0000000000000004e-06, 0.0003367597566851454]
epoch:  111, lr:[1.0000000000000004e-06, 0.0003199217688508881]
epoch:  112, lr:[1.0000000000000004e-06, 0.0003039256804083437]
epoch:  113, lr:[1.0000000000000004e-06, 0.0002887293963879265]
epoch:  114, lr:[1.0000000000000004e-06, 0.00027429292656853016]
epoch:  115, lr:[1.0000000000000004e-06, 0.00026057828024010366]
epoch:  116, lr:[1.0000000000000004e-06, 0.0002475493662280985]
epoch:  117, lr:[1.0000000000000004e-06, 0.00023517189791669353]
epoch:  118, lr:[1.0000000000000004e-06, 0.0002234133030208588]
epoch:  119, lr:[1.0000000000000005e-07, 0.00021224263786981585]
epoch:  120, lr:[1.0000000000000005e-07, 0.00020163050597632508]
epoch:  121, lr:[1.0000000000000005e-07, 0.0001915489806775088]
epoch:  122, lr:[1.0000000000000005e-07, 0.00018197153164363337]
epoch:  123, lr:[1.0000000000000005e-07, 0.00017287295506145168]
epoch:  124, lr:[1.0000000000000005e-07, 0.00016422930730837908]
epoch:  125, lr:[1.0000000000000005e-07, 0.00015601784194296014]
epoch:  126, lr:[1.0000000000000005e-07, 0.00014821694984581212]
epoch:  127, lr:[1.0000000000000005e-07, 0.0001408061023535215]
epoch:  128, lr:[1.0000000000000005e-07, 0.00013376579723584542]
epoch:  129, lr:[1.0000000000000005e-07, 0.00012707750737405313]
epoch:  130, lr:[1.0000000000000005e-07, 0.00012072363200535048]
epoch:  131, lr:[1.0000000000000005e-07, 0.00011468745040508295]
epoch:  132, lr:[1.0000000000000005e-07, 0.0001089530778848288]
epoch:  133, lr:[1.0000000000000005e-07, 0.00010350542399058736]
epoch:  134, lr:[1.0000000000000005e-07, 9.833015279105799e-05]
epoch:  135, lr:[1.0000000000000005e-07, 9.341364515150508e-05]
epoch:  136, lr:[1.0000000000000005e-07, 8.874296289392982e-05]
epoch:  137, lr:[1.0000000000000005e-07, 8.430581474923332e-05]
epoch:  138, lr:[1.0000000000000005e-07, 8.009052401177165e-05]
epoch:  139, lr:[1.0000000000000004e-08, 7.608599781118307e-05]
epoch:  140, lr:[1.0000000000000004e-08, 7.228169792062392e-05]
epoch:  141, lr:[1.0000000000000004e-08, 6.866761302459272e-05]
epoch:  142, lr:[1.0000000000000004e-08, 6.523423237336307e-05]
epoch:  143, lr:[1.0000000000000004e-08, 6.197252075469492e-05]
epoch:  144, lr:[1.0000000000000004e-08, 5.8873894716960165e-05]
epoch:  145, lr:[1.0000000000000004e-08, 5.593019998111216e-05]
epoch:  146, lr:[1.0000000000000004e-08, 5.313368998205655e-05]
epoch:  147, lr:[1.0000000000000004e-08, 5.0477005482953716e-05]
epoch:  148, lr:[1.0000000000000004e-08, 4.795315520880603e-05]
epoch:  149, lr:[1.0000000000000004e-08, 4.555549744836572e-05]
epoch:  150, lr:[1.0000000000000004e-08, 4.327772257594744e-05]
epoch:  151, lr:[1.0000000000000004e-08, 4.1113836447150066e-05]
epoch:  152, lr:[1.0000000000000004e-08, 3.905814462479256e-05]
epoch:  153, lr:[1.0000000000000004e-08, 3.710523739355293e-05]
epoch:  154, lr:[1.0000000000000004e-08, 3.524997552387528e-05]
epoch:  155, lr:[1.0000000000000004e-08, 3.3487476747681514e-05]
epoch:  156, lr:[1.0000000000000004e-08, 3.181310291029744e-05]
epoch:  157, lr:[1.0000000000000004e-08, 3.0222447764782564e-05]
epoch:  158, lr:[1.0000000000000004e-08, 2.8711325376543437e-05]
epoch:  159, lr:[1.0000000000000005e-09, 2.7275759107716264e-05]
epoch:  160, lr:[1.0000000000000005e-09, 2.5911971152330445e-05]
epoch:  161, lr:[1.0000000000000005e-09, 2.4616372594713925e-05]
epoch:  162, lr:[1.0000000000000005e-09, 2.3385553964978226e-05]
epoch:  163, lr:[1.0000000000000005e-09, 2.2216276266729317e-05]
epoch:  164, lr:[1.0000000000000005e-09, 2.110546245339285e-05]
epoch:  165, lr:[1.0000000000000005e-09, 2.0050189330723204e-05]
epoch:  166, lr:[1.0000000000000005e-09, 1.9047679864187045e-05]
epoch:  167, lr:[1.0000000000000005e-09, 1.809529587097769e-05]
epoch:  168, lr:[1.0000000000000005e-09, 1.7190531077428805e-05]
epoch:  169, lr:[1.0000000000000005e-09, 1.6331004523557364e-05]
epoch:  170, lr:[1.0000000000000005e-09, 1.5514454297379498e-05]
epoch:  171, lr:[1.0000000000000005e-09, 1.4738731582510519e-05]
epoch:  172, lr:[1.0000000000000005e-09, 1.4001795003384993e-05]
epoch:  173, lr:[1.0000000000000005e-09, 1.3301705253215743e-05]
epoch:  174, lr:[1.0000000000000005e-09, 1.2636619990554954e-05]
epoch:  175, lr:[1.0000000000000005e-09, 1.2004788991027206e-05]
epoch:  176, lr:[1.0000000000000005e-09, 1.1404549541475845e-05]
epoch:  177, lr:[1.0000000000000005e-09, 1.0834322064402054e-05]
epoch:  178, lr:[1.0000000000000005e-09, 1.029260596118195e-05]
epoch:  179, lr:[1.0000000000000006e-10, 9.777975663122852e-06]
epoch:  180, lr:[1.0000000000000006e-10, 9.28907687996671e-06]
epoch:  181, lr:[1.0000000000000006e-10, 8.824623035968373e-06]
epoch:  182, lr:[1.0000000000000006e-10, 8.383391884169954e-06]
epoch:  183, lr:[1.0000000000000006e-10, 7.964222289961456e-06]
epoch:  184, lr:[1.0000000000000006e-10, 7.566011175463383e-06]
epoch:  185, lr:[1.0000000000000006e-10, 7.187710616690214e-06]
epoch:  186, lr:[1.0000000000000006e-10, 6.828325085855702e-06]
epoch:  187, lr:[1.0000000000000006e-10, 6.486908831562916e-06]
epoch:  188, lr:[1.0000000000000006e-10, 6.16256338998477e-06]
epoch:  189, lr:[1.0000000000000006e-10, 5.854435220485532e-06]
epoch:  190, lr:[1.0000000000000006e-10, 5.5617134594612554e-06]
epoch:  191, lr:[1.0000000000000006e-10, 5.283627786488193e-06]
epoch:  192, lr:[1.0000000000000006e-10, 5.0194463971637825e-06]
epoch:  193, lr:[1.0000000000000006e-10, 4.768474077305593e-06]
epoch:  194, lr:[1.0000000000000006e-10, 4.5300503734403135e-06]
epoch:  195, lr:[1.0000000000000006e-10, 4.3035478547682975e-06]
epoch:  196, lr:[1.0000000000000006e-10, 4.088370462029883e-06]
epoch:  197, lr:[1.0000000000000006e-10, 3.883951938928388e-06]
epoch:  198, lr:[1.0000000000000006e-10, 3.6897543419819688e-06]
epoch:  199, lr:[1.0000000000000006e-11, 3.5052666248828703e-06]

2. TensorBoard: Introduction and Installation

Q: How do I install and launch TensorBoard?

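  • Install: conda install tensorboard typing-extensions (or pip install tensorboard)
  • Launch: tensorboard --logdir=./runs, then open the address it prints (http://localhost:6006 by default)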

Q: TensorBoard test code

In [9]:
import numpy as np
from torch.utils.tensorboard import SummaryWriter


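# no log_dir given: events go to a new subdirectory of ./runs whose name ends with the comment
writer = SummaryWriter(comment='test_tensorboard')

for x in range(100):

    # add_scalar(tag, scalar_value, global_step): one curve per tag
    writer.add_scalar('y=2x', x * 2, x)
    writer.add_scalar('y=pow(2, x)', 2 ** x, x)

    # add_scalars(main_tag, {sub_tag: value}, global_step): several curves in one chart
    writer.add_scalars('data/scalar_group', {"xsinx": x * np.sin(x),
                                             "xcosx": x * np.cos(x),
                                             "arctanx": np.arctan(x)}, x)
writer.close()  # flush pending events to disk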

3. Using TensorBoard (Part 1)

Q: What does SummaryWriter do, and what are its main parameters?

  • Purpose: provides a high-level interface for creating event files
  • class SummaryWriter(object):
        def __init__(self, log_dir=None, comment='', purge_step=None, max_queue=10, flush_secs=120, filename_suffix='')
    
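  • log_dir: output directory for the event files
  • comment: suffix appended to the auto-generated directory name when log_dir is not specified
  • filename_suffix: suffix appended to the event file name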
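When log_dir is omitted, SummaryWriter builds a directory name under ./runs from the current time and the hostname, and appends comment directly with no separator. The plain-Python sketch below reproduces that naming so you can predict where logs land; default_log_dir is a hypothetical helper, and the format string reflects PyTorch's implementation as we understand it, which may change between versions.

```python
import os
import socket
from datetime import datetime


def default_log_dir(comment=""):
    # runs/<Mon><DD>_<HH>-<MM>-<SS>_<hostname><comment>
    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    return os.path.join("runs", current_time + "_" + socket.gethostname() + comment)


print(default_log_dir(comment="test_tensorboard"))
```

This is also why tensorboard --logdir=./runs picks up every run: each writer gets its own subdirectory of ./runs.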