pytorch识别图片之002 图片标准化
AI
#想象你有一堆照片,有的偏亮、有的偏暗,有的偏红、有的偏蓝。标准化就是把这些照片都调整到统一的“标准光照”下,让模型训练更稳定。
# (0.4915, 0.4823, 0.4468), # 均值
# (0.2470, 0.2435, 0.2616) # 标准差
#训练集(50,000张图像
transforms_cifar10 = datasets.CIFAR10(root=data_path,train=True,download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4915,0.4823,0.4468),(0.2470,0.2435,0.2616))]))
#测试集(10,000张图像)
transforms_cifar10_val = datasets.CIFAR10(root=data_path,train=False,download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4915,0.4823,0.4468),(0.2470,0.2435,0.2616))]))
img_t,_ = transforms_cifar10[99]
# PyTorch 中图像张量的格式是 (通道, 高度, 宽度),而 matplotlib 显示需要 (高度, 宽度, 通道)。所以需要:
# permute(1, 2, 0):把维度顺序从 (0,1,2) 变为 (1,2,0)
plt.imshow(img_t.permute(1,2,0))
plt.show()
![[衡天云]爆款云服务器 低至12元/月](/hty.png)