# Files
# triplet-loss-cars/model_inference.py
# 2022-09-09 21:45:11 +03:00
#
# 177 lines
# 4.6 KiB
# Python

import pickle
import os
import sys
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
import random
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision import transforms
import cv2
# The YOLOv5 checkout is not an installed package; its top-level modules
# (models, utils) become importable once its directory is on sys.path.
# The path is relative, so it resolves against the current working directory.
_YOLOV5_DIR = './yolov5/'
if _YOLOV5_DIR not in sys.path:
    sys.path.append(_YOLOV5_DIR)
from models.common import DetectMultiBackend
from utils.augmentations import letterbox
from utils.general import scale_coords, non_max_suppression, check_img_size
from utils.dataloaders import LoadImages
from utils.plots import Annotator, colors, save_one_box
class EmbeddingModel(nn.Module):
    """Small CNN that maps a 3x64x64 image to an ``emb_dim``-dimensional vector.

    Three convolutional stages (3x3 conv, BatchNorm, PReLU, 2x2 max-pool) reduce
    a 64x64 input to 64 feature maps of 6x6. Two fully connected layers then
    project the flattened features to the embedding.

    The submodule layout (attribute names ``conv``/``fc`` and the position of
    every layer inside them) determines the state_dict keys. It must stay in
    sync with saved checkpoints.
    """

    # channels * height * width of self.conv's output for a 64x64 input:
    # 64 -> 62 -> 31 -> 29 -> 14 -> 12 -> 6 per spatial side.
    _FLAT_FEATURES = 64 * 6 * 6

    def __init__(self, emb_dim=128):
        super().__init__()
        stage1 = [nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.PReLU(), nn.MaxPool2d(2)]
        # Unlike the other stages, this PReLU learns one slope per channel (32).
        stage2 = [nn.Conv2d(16, 32, 3), nn.BatchNorm2d(32), nn.PReLU(32), nn.MaxPool2d(2)]
        # Activation precedes BatchNorm in this stage; kept to match checkpoints.
        stage3 = [nn.Conv2d(32, 64, 3), nn.PReLU(), nn.BatchNorm2d(64), nn.MaxPool2d(2)]
        self.conv = nn.Sequential(*stage1, *stage2, *stage3)
        self.fc = nn.Sequential(
            nn.Linear(self._FLAT_FEATURES, 256),
            nn.PReLU(),
            nn.Linear(256, emb_dim),
        )

    def forward(self, x):
        """Return one embedding row per image in the batch ``x``."""
        features = self.conv(x)
        flat = features.view(-1, self._FLAT_FEATURES)
        return self.fc(flat)
class SquarePad:
    """Zero-pad a (C, H, W) tensor image so that its height equals its width.

    The shorter spatial side is padded on both ends. When the size difference
    is odd, the extra pixel goes to the bottom/right, so the result is exactly
    max(H, W) x max(H, W). The previous version floored both halves and could
    come out one pixel short of square.
    """

    def __call__(self, image):
        _, h, w = image.size()  # tensor layout is (channels, height, width)
        side = max(h, w)
        top = (side - h) // 2
        bottom = side - h - top
        left = (side - w) // 2
        right = side - w - left
        # For the last two dims, torch.nn.functional.pad takes padding as
        # (left, right, top, bottom); fill with constant 0 as before.
        return torch.nn.functional.pad(image, (left, right, top, bottom), mode='constant', value=0)
class Normalize01:
    """Min-max rescale a tensor image to the [0, 1] range.

    Returns a new tensor and leaves the input untouched; the previous version
    modified the caller's tensor in place. A constant image (zero value range)
    maps to all zeros instead of producing NaN from 0/0.
    """

    def __call__(self, image):
        shifted = image - image.min()
        value_range = shifted.max()
        if value_range == 0:
            # Flat input (e.g. a uniformly coloured crop): nothing to scale.
            return shifted
        return shifted / value_range
def prepare_for_embedding(image):
    """Turn one image crop into a batched tensor ready for EmbeddingModel.

    Steps:
    1. Move the last (channel) axis first.
    2. Cast to float32 and reverse the channel order.
    3. Pad to a square and resize to 64x64.
    4. Min-max scale to [0, 1].

    NOTE(review): assumes ``image`` is an (H, W, 3) array in BGR order, such as
    the crop produced by ``save_one_box(..., BGR=True)``, so that the channel
    flip yields RGB. Confirm this matches the training preprocessing.

    Returns:
        A tensor of shape (1, C, 64, 64).
    """
    chw = torch.tensor(image).permute((2, 0, 1))
    rgb = chw.float().flip(0)
    pipeline = transforms.Compose([
        SquarePad(),
        transforms.Resize((64, 64)),
        Normalize01(),
    ])
    return pipeline(rgb).unsqueeze(dim=0)
# Pick the compute device once; the rest of the script moves models and tensors to it.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using GPU.' if device == 'cuda' else "CUDA not detected, using CPU.")
# Embedding network. map_location lets a checkpoint saved during a GPU session
# load on the CPU fallback path instead of failing inside torch.load.
model_embedding = EmbeddingModel()
model_embedding.load_state_dict(torch.load('./embedding-output/model_embedding.pt', map_location=device))
model_embedding.to(device)
model_embedding.eval()  # BatchNorm layers use running statistics at inference time

# Classifier over embeddings.
# NOTE(review): pickle.load can execute arbitrary code from the file; only load
# artifacts you produced yourself.
# NOTE(review): this absolute path differs from the relative paths used for the
# other artifacts; confirm '/model_classifier.obj' is the intended location.
with open('/model_classifier.obj', 'rb') as file:
    model_classifier = pickle.load(file)
# Class labels, presumably aligned with predict_proba()'s columns (the
# scikit-learn convention). Read the public attribute rather than
# __getstate__()['classes_'], so estimators exposing it as a property also work.
classes = model_classifier.classes_
# Detect cars in one video, classify each detection from its embedding, and
# write the annotated frames to ./detection-output/res.mp4.
video = Path('/content/test_videos_2022/2022-NLS-5-NLS_05_2022_Heli_UHD_01-000140-000155-Karussell.mp4')

# Probe the source video for the properties the output writer must match.
reader = cv2.VideoCapture(str(video))
fps = reader.get(cv2.CAP_PROP_FPS)
w = int(reader.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
reader.release()

# The detector must be created before its stride is queried. The previous
# ordering referenced `model` before assignment and raised NameError.
weights_path = Path('./yolov5/best.pt')
model = DetectMultiBackend(weights_path, device=torch.device(device))

# YOLOv5's letterbox() reads img_size as (height, width), so pass (h, w).
# check_img_size adjusts each side to be compatible with the model stride.
imgsz = check_img_size((h, w), s=model.stride)
dataset = LoadImages(video, img_size=imgsz, stride=model.stride, auto=model.pt)

save_dir = Path('./detection-output/')
os.makedirs(save_dir, exist_ok=True)  # allow re-running without deleting old output
writer = cv2.VideoWriter(str(save_dir / 'res.mp4'), cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))

try:
    # Pure inference: skip building autograd graphs for both networks.
    with torch.no_grad():
        for frame_n, (_, im, im0s, _, _) in enumerate(dataset):
            im = torch.from_numpy(im).to(device)
            im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
            im /= 255  # 0 - 255 to 0.0 - 1.0
            im = im[None]  # batch of one frame
            pred = model(im)
            pred = non_max_suppression(pred, conf_thres=0.5, max_det=100)

            # One frame per batch, so there is a single detection tensor.
            det = pred[0]
            # Draw on a copy; crops below are cut from the untouched frame im0s.
            annotator = Annotator(im0s.copy(), line_width=3, example=str(model.names), pil=True, font_size=20)
            if len(det):
                # Map boxes from the letterboxed input back to the original frame.
                det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0s.shape).round()
                for k, (*xyxy, _conf, _cls) in enumerate(reversed(det)):
                    crop = save_one_box(xyxy, im0s, file=save_dir / 'crops' / f'frame{frame_n}_{k}.jpg', BGR=True)
                    image = prepare_for_embedding(crop).to(device)
                    embedding = model_embedding(image).cpu().numpy()
                    probabilities = model_classifier.predict_proba(embedding)[0]
                    best = np.argmax(probabilities)
                    # Plain ints/str for the drawing backend instead of 0-d tensors / numpy scalars.
                    annotator.text([int(xyxy[0]) - 20, int(xyxy[1]) - 20], str(classes[best]))
            # Write every frame, including those without detections, so the output
            # keeps the source's frame count and timing.
            writer.write(annotator.result().copy())
finally:
    writer.release()