📚博客主页:knighthood2001
✨公众号:认知up吧 (目前正在带领大家一起提升认知,感兴趣可以来围观一下)
🎃知识星球:【认知up吧|成长|副业】介绍
❤️如遇文章付费,可先看看我公众号中是否发布免费文章❤️
🙏笔者水平有限,欢迎各位大佬指点,相互学习进步!
这块内容在前文讲过了,这里补充几个点。
import torch
import torch.utils.data as Data
from torchvision import transforms
from model import GoogLeNet, Inception
from torchvision.datasets import ImageFolder
from PIL import Image
def test_data_process():
# 定义数据集的路径
ROOT_TRAIN = r'data\test'
normalize = transforms.Normalize([0.162, 0.151, 0.138], [0.058, 0.052, 0.048])
# 定义数据集处理方法变量
test_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), normalize])
# 加载数据集
test_data = ImageFolder(ROOT_TRAIN, transform=test_transform)
test_dataloader = Data.DataLoader(dataset=test_data,
batch_size=1,
shuffle=True,
num_workers=0)
return test_dataloader
- 在测试数据集加载中,batch_size设置成1,可以方便我们获得每个样本的精确预测结果。
- normalize如果计算过,可以加,否则,删掉这个即可。
normalize = transforms.Normalize([0.162, 0.151, 0.138], [0.058, 0.052, 0.048])