PyTorch深度学习开发 使用预训练的ResNet网络给图片分类
AI

import ssl
ssl._create_default_https_context = ssl._create_unverified_context
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
# 1. 加载模型
model = models.resnet101(pretrained=True)
model.eval() # 设置为预测模式
# 3. 图片处理
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 4. 识别函数
def recognize_image(image_path):
# 打开图片
image = Image.open('/Volumes/c/work/aigc/1.png').convert('RGB') # 关键:转换为RGB
print(image)
# 处理图片
input_tensor = preprocess(image)
input_batch = torch.unsqueeze(input_tensor,0) # 增加一个维度
# 预测
with torch.no_grad():
output = model(input_batch)
# 获取概率
probabilities = torch.nn.functional.softmax(output[0], dim=0)
# 取前3个结果
top3_prob, top3_catid = torch.topk(probabilities, 3)
# 默认类型
from torchvision.models import ResNet101_Weights
weights = ResNet101_Weights.DEFAULT
imagenet_classes = weights.meta["categories"]
# 打印结果
print(f"\n识别结果: {image_path}")
print("-" * 30)
for i in range(top3_prob.size(0)):
cat_id = top3_catid[i].item()
prob = top3_prob[i].item() * 100
label = imagenet_classes[cat_id]
print(f"{i + 1}. {label} ({prob:.1f}%)")
print("-" * 30)
# 5. 使用示例
if __name__ == "__main__":
# 识别自己的图片(替换成你的图片路径)
image_path = "1.png" # 改为你的图片路径
try:
recognize_image(image_path)
except FileNotFoundError:
print(f"找不到图片: {image_path}")
except Exception as e:
print(f"出错: {e}")![[衡天云]爆款云服务器 低至12元/月](/hty.png)