Git Product home page Git Product logo

pytorch-camp's Issues

lesson-21 / loss_acc_weights_grad.py

  1. 在158行,因为valid_curve.append(loss.item())是在读取验证集的for循环的外面,所以最后的valid_curve列表里面只有一个损失值(最后一次读取验证集数据得到的损失)。
  2. 在159行打印loss_val时,打印的是一个epoch里,每次读取batch_size个验证集样本,来计算batch_size个样本损失的均值loss.item(),读多少次就计算多少个均值loss.item(),这里读取了2次,之后将2次的均值求和得到loss_val,这样打印的应该不是一个epoch的损失均值把,应该是loss_val/len(valid_loader)才对吧。
  3. 在163行,np.mean(valid_curve)只有一个数据,求均值还是最后一次读取的损失。
  4. 我的观点:在158行前面加上,loss_val_epoch=loss_val/len(valid_loader)。之后都用loss_val_epoch。

关于人民币二分类正确Label判断的问题

在人民币二分类任务中,训练模型时,有这两行代码

# train
correct += (predicted == labels).squeeze().sum().numpy()
# valid
correct_val += (predicted == labels).squeeze().sum().numpy()

我不太理解(predicted == labels)这个之后为什么需要加一个squeeze()。
我在python console进行了实验,验证不加squeeze()直接进行sum()和numpy()也可以得到同样的结果

# 预测的mini_batch中的结果
predicted = torch.tensor([0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1])
# ground truth
labels = torch.tensor([0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0])
# 不加squeeze(), 不去掉为1的维度
(predicted == labels).sum().numpy()
Out[33]: array(14, dtype=int64)
# 加 squeeze(), 去掉为1的维度
(predicted == labels).squeeze().sum().numpy()
Out[34]: array(14, dtype=int64)

所以,老师,这个squeeze()的作用是什么呢?

3_multi_gpu.py

37行'if not gpu_memory:'是不是应该改为'if gpu_memory:'。因为能计算显存才可执行下面的代码选择物理gpu。

lesson_8

为什么本节课的代码中的transform_invert函数没有declaration

lesson08

请问老师为什么fivecrop ,tencrop后面的处理不需要normalize处理呢

RMB_data_augmentation.py

test_data = RMBDataset(data_dir=test_dir, transform=valid_transform) 这一行用的valid_transform。
后面反transforms时 img = transform_invert(img_tensor, train_transform),用train_transform是不是不对呀,也需要用valid_transfrom叭

没有数据集

老师好,可以麻烦您提供以下数据集吗?对应的路径下找不到数据集,谢谢。

lesson-21 / loss_acc_weights_grad.py

补充:

  1. 在159和164行,计算准确率时,继续用correct / total,这个不是验证集的准确率叭,应该是correct_val / total_val

finetune_resnet18.py

  1. 在finetune_resnet18.py中的166行,loss_mean = 0.是否多余。
    因为当能运行到这一行时,就表示已经跑了一个epoch,将下面的代码执行完,就会回到132行进行下一个epoch,之后执行到134行,'loss_mean'也会重新赋值为0.。所以166我认为是多余的。

  2. 我下载的数据里,训练集里面ants是124张,bees是121张。
    BATCH_SIZE=16时,ants和bees读取8次左右就可以看着是一个epoch。您给的代码里log_interval=10。这样是否冗余?还是我这边数据集图片的张数不对?

关于test_data的路径问题

请问老师,为什么test_data下必须建一个为名100的文件夹,然后把图片放进100文件夹里才能进行验证,直接将验证图片放在test-data下就会出现报错呢?相关代码在哪里体现的呢?

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.