物体检测-垃圾分类
云中有鹿 2020/4/30
import os
import torch
from tqdm import tqdm
import math
import torchvision
from PIL import Image,ImageDraw,ImageFont
from torch import autograd
import torchvision.transforms as T
# Option 1: rebuild the network and load its weights separately.
# Local copy of torchvision's COCO-trained Faster R-CNN checkpoint (read by get_model).
CKP_PATH = "./fasterrcnn_resnet50_fpn_coco-258fb6c6.pth"
# Fine-tuned state_dict loaded into the 205-class model by Load_model.
Weight_PATH = './TrainedNet1.pt'
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
# Define the Faster R-CNN network; the main change is the number of predicted classes.
def get_model(num_classes, ckp_path=CKP_PATH):
    """Return a Faster R-CNN (ResNet-50 FPN) whose box predictor has *num_classes* outputs.

    The architecture is built without downloading anything, the COCO-trained
    weights are loaded from the local checkpoint *ckp_path* (defaults to
    CKP_PATH), and the box-predictor head is then replaced by a freshly
    initialised FastRCNNPredictor sized for *num_classes*.

    Args:
        num_classes: number of outputs of the new classification head.
        ckp_path: path of the COCO-trained Faster R-CNN state_dict.

    Returns:
        The model. Only the new head is untrained; everything else carries
        the checkpoint weights.
    """
    # Build the bare detection architecture; weights come from ckp_path, not the network.
    # NOTE(review): torchvision >= 0.13 deprecates pretrained/pretrained_backbone in
    # favour of weights=None, weights_backbone=None; kept here for older torchvision.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        pretrained=False, pretrained_backbone=False
    )
    # map_location='cpu' lets the checkpoint load on machines without CUDA
    # (consistent with how Load_model loads the fine-tuned weights).
    model.load_state_dict(torch.load(ckp_path, map_location='cpu'))
    # Input width of the existing classifier, needed to size the replacement head.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # Swap the COCO head for one with the requested number of classes.
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model
# Build the model and load the fine-tuned weights.
def Load_model(num_classes=205, weights_path=Weight_PATH):
    """Return the garbage-detection model in eval mode.

    Builds the network with ``get_model(num_classes)``. If *weights_path*
    exists, its fine-tuned state_dict is loaded on top.

    Args:
        num_classes: size of the box-predictor head. The default of 205 is
            the value that was previously hard-coded.
        weights_path: path of the fine-tuned state_dict (defaults to Weight_PATH).

    Returns:
        The model, switched to eval mode.

    Warns:
        RuntimeWarning: if *weights_path* does not exist. The model is still
            returned, but its head is the untrained one created by get_model.
    """
    import warnings

    model = get_model(num_classes=num_classes)
    if os.path.exists(weights_path):
        # map_location='cpu' so GPU-saved weights also load on CPU-only machines.
        model.load_state_dict(torch.load(weights_path, map_location='cpu'))
    else:
        # This case used to be silent, leaving the freshly initialised head from
        # get_model in place and producing arbitrary labels/scores.
        warnings.warn(
            f"weights file {weights_path!r} not found; the box predictor head "
            "is untrained and predictions will not be meaningful",
            RuntimeWarning,
        )
    model.eval()
    return model
# Load the model file.
# Two ways to obtain the detector are supported:
#   1. rebuild the architecture and load a state_dict (Load_model);
#   2. unpickle a complete model object saved with torch.save(model, path).
# The pickled model is preferred when present. Only if it is missing is the
# model rebuilt, instead of always building one and then discarding it.
FULL_MODEL_PATH = 'my_model.pth'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if os.path.exists(FULL_MODEL_PATH):
    # NOTE(security): unpickling a full model can execute arbitrary code; only
    # load model files from a trusted source.
    try:
        # torch >= 2.6 defaults to weights_only=True, which refuses full
        # nn.Module pickles, so opt out explicitly.
        model = torch.load(FULL_MODEL_PATH, map_location='cpu', weights_only=False)
    except TypeError:
        # torch < 1.13 does not know the weights_only argument.
        model = torch.load(FULL_MODEL_PATH, map_location='cpu')
else:
    model = Load_model()
# A pickled module keeps whatever train/eval mode it was saved in, and
# FasterRCNN requires targets in training mode, so force inference mode.
model.eval()
# Image preprocessing.
def image_process(Image_url):
    """Load the image at *Image_url* and return it as a tensor for the detector.

    The image is converted to 3-channel RGB first. Otherwise RGBA or palette
    files would yield a tensor with a different channel count, while the
    ResNet-50 backbone is built for 3-channel input.
    ToTensor produces a float (C, H, W) tensor scaled to [0, 1].

    Args:
        Image_url: local path of the image file.

    Returns:
        The image as a (3, H, W) float tensor.
    """
    transform = T.Compose([T.ToTensor()])  # define the PyTorch transform
    # The context manager releases the file handle; convert() forces the
    # pixel data to be read before the file is closed.
    with Image.open(Image_url) as img:
        rgb = img.convert('RGB')
    return transform(rgb)  # apply the transform to the image
# Turn predictions into result records (used to check prediction quality on the validation set).
def boxes_to_lines(preds, json_dict, score_threshold=0.4):
    """Convert torchvision detection output for one image into result dicts.

    Args:
        preds: detector output. Only ``preds[0]`` is read, and it must hold
            ``"boxes"`` (xyxy), ``"labels"`` and ``"scores"`` tensors.
        json_dict: mapping from a label id as a string (e.g. ``"5"``) to the
            value stored in the ``"category_id"`` field.
        score_threshold: only detections with a score strictly greater than
            this are kept. The default of 0.4 is the previously hard-coded value.

    Returns:
        A list of ``{"bbox": [x, y, w, h], "category_id": ..., "score": s}``
        dicts. The numbers are plain Python floats, so the list can be passed
        straight to ``json.dump``; numpy float32 values would make it fail.

    Raises:
        KeyError: if a kept label has no entry in *json_dict*.
    """
    pred = preds[0]
    boxes = pred["boxes"].detach().cpu().numpy()
    labels = pred["labels"].detach().cpu().numpy()
    scores = pred["scores"].detach().cpu().numpy()

    r = []
    for (x1, y1, x2, y2), label, score in zip(boxes, labels, scores):
        if score > score_threshold:
            # torchvision emits boxes as xyxy; the result records use xywh.
            r.append(
                {
                    "bbox": [float(x1), float(y1), float(x2 - x1), float(y2 - y1)],
                    "category_id": json_dict[str(int(label))],
                    "score": float(score),
                }
            )
    return r
# Load the class-mapping file that turns predicted label ids into class names.
import json

with open('./gabage-class.json', "r", encoding='utf-8') as fp:
    # json.load(fp) is exactly json.loads(fp.read()); read the text, then parse it.
    json_dict = json.loads(fp.read())
# Show prediction results for one sample image.
test_show_img_url = '../train/2020-gaebage/50345826a11f.JPG'
img_tensor = image_process(test_show_img_url)
# Inference only: no_grad avoids building an autograd graph. Calling the module
# (rather than .forward) also runs any registered hooks.
with torch.no_grad():
    preds = model([img_tensor])
result = boxes_to_lines(preds, json_dict)
try:
    # simsun.ttc is a CJK font (typically present on Windows) so the Chinese labels render.
    font = ImageFont.truetype('simsun.ttc', 29)
except OSError:
    # NOTE(review): the default bitmap font likely cannot render CJK text, but
    # falling back keeps the visualisation from crashing when SimSun is absent.
    font = ImageFont.load_default()
base = Image.open(test_show_img_url).convert('RGBA')
d = ImageDraw.Draw(base)
for item in result:
    x, y, w, h = (int(v) for v in item['bbox'])
    # Draw the box first so its outline does not paint over the label text.
    d.rectangle([x, y, x + w, y + h], outline='RED')  # add fill="red" to fill the box
    d.text((x, y), item['category_id'], (255, 255, 0), font=font)  # class label
# In a notebook, these trailing expressions display the annotated image and the result list.
base
result
[{'bbox': [650.15063, 204.76646, 197.27887, 238.59322],
'category_id': '西瓜皮_湿垃圾',
'score': 0.9993536},
{'bbox': [984.9458, 506.62936, 99.37793, 102.61508],
'category_id': '橡皮泥_干垃圾',
'score': 0.99911386},
{'bbox': [1169.229, 486.35214, 153.3247, 142.97177],
'category_id': '粉笔_干垃圾',
'score': 0.9983192},
{'bbox': [1361.24, 487.92105, 33.58069, 30.368927],
'category_id': '药片_有害垃圾',
'score': 0.9971042},
{'bbox': [813.12195, 431.14136, 57.010986, 183.35211],
'category_id': '鸡骨头_湿垃圾',
'score': 0.99564517},
{'bbox': [969.70544, 180.80382, 125.880005, 297.3694],
'category_id': '玉米棒_湿垃圾',
'score': 0.99445397},
{'bbox': [923.8991, 548.4976, 89.91388, 92.52167],
'category_id': '动物内脏_湿垃圾',
'score': 0.9941591},
{'bbox': [1093.0338, 208.35583, 198.45679, 227.7938],
'category_id': '粽子_湿垃圾',
'score': 0.99379784},
{'bbox': [815.61865, 312.36667, 157.42194, 92.74896],
'category_id': '金属工具_可回收垃圾',
'score': 0.9934223},
{'bbox': [1307.6252, 366.43896, 101.814575, 120.547455],
'category_id': '金属工具_可回收垃圾',
'score': 0.9883195},
{'bbox': [452.42352, 393.40768, 335.7898, 580.0679],
'category_id': '毛发_干垃圾',
'score': 0.967124},
{'bbox': [1219.7949, 581.2517, 227.08667, 204.16669],
'category_id': '农药瓶_有害垃圾',
'score': 0.944504},
{'bbox': [476.8133, 471.62048, 361.78912, 73.64337],
'category_id': '口红_干垃圾',
'score': 0.81554884},
{'bbox': [935.1674, 183.82547, 79.83429, 97.20517],
'category_id': '榴莲壳_干垃圾',
'score': 0.69249374},
{'bbox': [479.96027, 454.2975, 369.42932, 88.35443],
'category_id': '笔_干垃圾',
'score': 0.50012475}]
评论区