pytorch识别图片之003 图片类型判断
AI
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
from torchvision import datasets
from matplotlib import pyplot as plt
import numpy as np
import torch
from torchvision import transforms
data_path = '/Volumes/c/work/aigc'
transforms_cifar10 = datasets.CIFAR10(root=data_path,train=True,download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4915,0.4823,0.4468),(0.2470,0.2435,0.2616))]))
import torch.nn as nn
model= nn.Sequential(
nn.Linear(3*32*32,512),
nn.Tanh(),
nn.Linear(512,2), # 输出飞机和鸟的概率 2
nn.Softmax(dim=1)
)
softmax = nn.Softmax(dim=1) #放大大的 缩小小的
label_map={0:0,2:1}
class_names=['airplane','bird']
cifar2= [(img,label_map[label]) for img,label in transforms_cifar10 if label in [0,2]]
img,_=cifar2[0]
plt.imshow(img.permute(1,2,0))
plt.show()
# 数据形状的变化过程
# 原始图片: [3, 32, 32]
# ↓ view(-1)
# 展平后: [3072]
# ↓ unsqueeze(0)
# 最终: [1, 3072] ← 符合模型要求的输入格式
img_batch = img.view(-1).unsqueeze(0)
out = model(img_batch)
print(out)
![[衡天云]爆款云服务器 低至12元/月](/hty.png)