Compare commits
5 Commits
6a8c963dd9
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| bae1b00bdc | |||
| c6794e7f03 | |||
| 9b469860fa | |||
| c81d9106d5 | |||
| 68812a209b |
BIN
.model_inference.py.swp
Normal file
BIN
.model_inference.py.swp
Normal file
Binary file not shown.
@@ -2,5 +2,5 @@
|
||||
"n_neighbors": 10,
|
||||
"classes" : ["black_red", "green_orange", "yellow_grey"],
|
||||
"path_to_dataset": "./triplet_dataset",
|
||||
"embedding_model": "./model_embedding.pt"
|
||||
"embedding_model": "./embedding-output/model_embedding.pt"
|
||||
}
|
||||
|
||||
4
download-test-videos
Executable file
4
download-test-videos
Executable file
@@ -0,0 +1,4 @@
|
||||
#!/bin/bash
|
||||
gdown 1Mm24z7fe1fkbcTt05IQpLlNdSALJQFRc
|
||||
unzip -q test_videos_2022.zip
|
||||
rm test_videos_2022.zip
|
||||
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"epochs": 4,
|
||||
"epochs": 1,
|
||||
"embedding_dims": 128,
|
||||
"batch_size": 32,
|
||||
"classes" : ["black_red", "green_orange", "yellow_grey"],
|
||||
|
||||
611
example.ipynb
Normal file
611
example.ipynb
Normal file
@@ -0,0 +1,611 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": [],
|
||||
"collapsed_sections": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
},
|
||||
"accelerator": "GPU",
|
||||
"gpuClass": "standard"
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "DbeYatBLu8pf",
|
||||
"outputId": "4ed57e67-ae61-4662-e0cc-faed1c4062e4"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Cloning into 'triplet-loss-cars'...\n",
|
||||
"remote: Enumerating objects: 17560, done.\u001b[K\n",
|
||||
"remote: Counting objects: 100% (17560/17560), done.\u001b[K\n",
|
||||
"remote: Compressing objects: 100% (9229/9229), done.\u001b[K\n",
|
||||
"remote: Total 17560 (delta 8336), reused 17531 (delta 8326)\u001b[K\n",
|
||||
"Receiving objects: 100% (17560/17560), 773.26 MiB | 16.83 MiB/s, done.\n",
|
||||
"Resolving deltas: 100% (8336/8336), done.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!git clone https://git.drivecast.tech/vd/triplet-loss-cars.git"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"%cd triplet-loss-cars/"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "264RI2RHu_u4",
|
||||
"outputId": "5cea8550-3c8f-453f-ee09-de205e117769"
|
||||
},
|
||||
"execution_count": 2,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"/content/triplet-loss-cars\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!./download-dataset"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "lboCMkqrvQtO",
|
||||
"outputId": "015dad3e-fda4-4daf-a86c-54bf37253fa4"
|
||||
},
|
||||
"execution_count": 3,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Downloading...\n",
|
||||
"From: https://drive.google.com/uc?id=1rP7GHDqx6BKTGTh9I6ecEmRgn5-HG1N0\n",
|
||||
"To: /content/triplet-loss-cars/triplet_dataset.zip\n",
|
||||
"100% 27.3M/27.3M [00:01<00:00, 20.6MB/s]\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [],
|
||||
"metadata": {
|
||||
"id": "Jx07-q67wGfD"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!python3 train-embedding.py"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "iat8B1qNv96T",
|
||||
"outputId": "712601bf-9a2a-44d9-8d27-6dbc82a941f1"
|
||||
},
|
||||
"execution_count": 7,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Using GPU.\n",
|
||||
"Training.\n",
|
||||
"Epoch: 1/10 - Training loss: 0.2827 - Test loss: 0.1240\n",
|
||||
"Epoch: 2/10 - Training loss: 0.1167 - Test loss: 0.1096\n",
|
||||
"Epoch: 3/10 - Training loss: 0.0887 - Test loss: 0.0779\n",
|
||||
"Epoch: 4/10 - Training loss: 0.0789 - Test loss: 0.0882\n",
|
||||
"Epoch: 5/10 - Training loss: 0.0718 - Test loss: 0.0383\n",
|
||||
"Epoch: 6/10 - Training loss: 0.0643 - Test loss: 0.0566\n",
|
||||
"^C\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!python3 train-classifier.py"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "XhCmibyKwC-U",
|
||||
"outputId": "e4b28c5a-d878-40d5-aaaf-f256e1ab2fa7"
|
||||
},
|
||||
"execution_count": 8,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Using GPU.\n",
|
||||
"Score: 0.9918\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!./download-test-videos"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "9PRr_7pV4qwE",
|
||||
"outputId": "a71579b1-992b-41db-a275-79e5e55b5715"
|
||||
},
|
||||
"execution_count": 10,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Downloading...\n",
|
||||
"From: https://drive.google.com/uc?id=1Mm24z7fe1fkbcTt05IQpLlNdSALJQFRc\n",
|
||||
"To: /content/triplet-loss-cars/test_videos_2022.zip\n",
|
||||
"100% 212M/212M [00:03<00:00, 63.2MB/s]\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!unzip -q ./yolov5.zip"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "wZjPyaNo6XAa"
|
||||
},
|
||||
"execution_count": 11,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"!python3 model_inference.py"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "EW8pO4f86fBj",
|
||||
"outputId": "15c057b3-c4df-43cb-9a6a-4201404b226e"
|
||||
},
|
||||
"execution_count": 12,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": [
|
||||
"Using GPU.\n",
|
||||
"Fusing layers... \n",
|
||||
"custom_YOLOv5l summary: 290 layers, 20852934 parameters, 0 gradients\n",
|
||||
"WARNING: --img-size [1920, 1080] must be multiple of max stride 32, updating to [1920, 1088]\n",
|
||||
"0\n",
|
||||
"Downloading https://ultralytics.com/assets/Arial.ttf to /root/.config/Ultralytics/Arial.ttf...\n",
|
||||
"1\n",
|
||||
"2\n",
|
||||
"3\n",
|
||||
"4\n",
|
||||
"5\n",
|
||||
"6\n",
|
||||
"7\n",
|
||||
"8\n",
|
||||
"9\n",
|
||||
"10\n",
|
||||
"11\n",
|
||||
"12\n",
|
||||
"13\n",
|
||||
"14\n",
|
||||
"15\n",
|
||||
"16\n",
|
||||
"17\n",
|
||||
"18\n",
|
||||
"19\n",
|
||||
"20\n",
|
||||
"21\n",
|
||||
"22\n",
|
||||
"23\n",
|
||||
"24\n",
|
||||
"25\n",
|
||||
"26\n",
|
||||
"27\n",
|
||||
"28\n",
|
||||
"29\n",
|
||||
"30\n",
|
||||
"31\n",
|
||||
"32\n",
|
||||
"33\n",
|
||||
"34\n",
|
||||
"35\n",
|
||||
"36\n",
|
||||
"37\n",
|
||||
"38\n",
|
||||
"39\n",
|
||||
"40\n",
|
||||
"41\n",
|
||||
"42\n",
|
||||
"43\n",
|
||||
"44\n",
|
||||
"45\n",
|
||||
"46\n",
|
||||
"47\n",
|
||||
"48\n",
|
||||
"49\n",
|
||||
"50\n",
|
||||
"51\n",
|
||||
"52\n",
|
||||
"53\n",
|
||||
"54\n",
|
||||
"55\n",
|
||||
"56\n",
|
||||
"57\n",
|
||||
"58\n",
|
||||
"59\n",
|
||||
"60\n",
|
||||
"61\n",
|
||||
"62\n",
|
||||
"63\n",
|
||||
"64\n",
|
||||
"65\n",
|
||||
"66\n",
|
||||
"67\n",
|
||||
"68\n",
|
||||
"69\n",
|
||||
"70\n",
|
||||
"71\n",
|
||||
"72\n",
|
||||
"73\n",
|
||||
"74\n",
|
||||
"75\n",
|
||||
"76\n",
|
||||
"77\n",
|
||||
"78\n",
|
||||
"79\n",
|
||||
"80\n",
|
||||
"81\n",
|
||||
"82\n",
|
||||
"83\n",
|
||||
"84\n",
|
||||
"85\n",
|
||||
"86\n",
|
||||
"87\n",
|
||||
"88\n",
|
||||
"89\n",
|
||||
"90\n",
|
||||
"91\n",
|
||||
"92\n",
|
||||
"93\n",
|
||||
"94\n",
|
||||
"95\n",
|
||||
"96\n",
|
||||
"97\n",
|
||||
"98\n",
|
||||
"99\n",
|
||||
"100\n",
|
||||
"101\n",
|
||||
"102\n",
|
||||
"103\n",
|
||||
"104\n",
|
||||
"105\n",
|
||||
"106\n",
|
||||
"107\n",
|
||||
"108\n",
|
||||
"109\n",
|
||||
"110\n",
|
||||
"111\n",
|
||||
"112\n",
|
||||
"113\n",
|
||||
"114\n",
|
||||
"115\n",
|
||||
"116\n",
|
||||
"117\n",
|
||||
"118\n",
|
||||
"119\n",
|
||||
"120\n",
|
||||
"121\n",
|
||||
"122\n",
|
||||
"123\n",
|
||||
"124\n",
|
||||
"125\n",
|
||||
"126\n",
|
||||
"127\n",
|
||||
"128\n",
|
||||
"129\n",
|
||||
"130\n",
|
||||
"131\n",
|
||||
"132\n",
|
||||
"133\n",
|
||||
"134\n",
|
||||
"135\n",
|
||||
"136\n",
|
||||
"137\n",
|
||||
"138\n",
|
||||
"139\n",
|
||||
"140\n",
|
||||
"141\n",
|
||||
"142\n",
|
||||
"143\n",
|
||||
"144\n",
|
||||
"145\n",
|
||||
"146\n",
|
||||
"147\n",
|
||||
"148\n",
|
||||
"149\n",
|
||||
"150\n",
|
||||
"151\n",
|
||||
"152\n",
|
||||
"153\n",
|
||||
"154\n",
|
||||
"155\n",
|
||||
"156\n",
|
||||
"157\n",
|
||||
"158\n",
|
||||
"159\n",
|
||||
"160\n",
|
||||
"161\n",
|
||||
"162\n",
|
||||
"163\n",
|
||||
"164\n",
|
||||
"165\n",
|
||||
"166\n",
|
||||
"167\n",
|
||||
"168\n",
|
||||
"169\n",
|
||||
"170\n",
|
||||
"171\n",
|
||||
"172\n",
|
||||
"173\n",
|
||||
"174\n",
|
||||
"175\n",
|
||||
"176\n",
|
||||
"177\n",
|
||||
"178\n",
|
||||
"179\n",
|
||||
"180\n",
|
||||
"181\n",
|
||||
"182\n",
|
||||
"183\n",
|
||||
"184\n",
|
||||
"185\n",
|
||||
"186\n",
|
||||
"187\n",
|
||||
"188\n",
|
||||
"189\n",
|
||||
"190\n",
|
||||
"191\n",
|
||||
"192\n",
|
||||
"193\n",
|
||||
"194\n",
|
||||
"195\n",
|
||||
"196\n",
|
||||
"197\n",
|
||||
"198\n",
|
||||
"199\n",
|
||||
"200\n",
|
||||
"201\n",
|
||||
"202\n",
|
||||
"203\n",
|
||||
"204\n",
|
||||
"205\n",
|
||||
"206\n",
|
||||
"207\n",
|
||||
"208\n",
|
||||
"209\n",
|
||||
"210\n",
|
||||
"211\n",
|
||||
"212\n",
|
||||
"213\n",
|
||||
"214\n",
|
||||
"215\n",
|
||||
"216\n",
|
||||
"217\n",
|
||||
"218\n",
|
||||
"219\n",
|
||||
"220\n",
|
||||
"221\n",
|
||||
"222\n",
|
||||
"223\n",
|
||||
"224\n",
|
||||
"225\n",
|
||||
"226\n",
|
||||
"227\n",
|
||||
"228\n",
|
||||
"229\n",
|
||||
"230\n",
|
||||
"231\n",
|
||||
"232\n",
|
||||
"233\n",
|
||||
"234\n",
|
||||
"235\n",
|
||||
"236\n",
|
||||
"237\n",
|
||||
"238\n",
|
||||
"239\n",
|
||||
"240\n",
|
||||
"241\n",
|
||||
"242\n",
|
||||
"243\n",
|
||||
"244\n",
|
||||
"245\n",
|
||||
"246\n",
|
||||
"247\n",
|
||||
"248\n",
|
||||
"249\n",
|
||||
"250\n",
|
||||
"251\n",
|
||||
"252\n",
|
||||
"253\n",
|
||||
"254\n",
|
||||
"255\n",
|
||||
"256\n",
|
||||
"257\n",
|
||||
"258\n",
|
||||
"259\n",
|
||||
"260\n",
|
||||
"261\n",
|
||||
"262\n",
|
||||
"263\n",
|
||||
"264\n",
|
||||
"265\n",
|
||||
"266\n",
|
||||
"267\n",
|
||||
"268\n",
|
||||
"269\n",
|
||||
"270\n",
|
||||
"271\n",
|
||||
"272\n",
|
||||
"273\n",
|
||||
"274\n",
|
||||
"275\n",
|
||||
"276\n",
|
||||
"277\n",
|
||||
"278\n",
|
||||
"279\n",
|
||||
"280\n",
|
||||
"281\n",
|
||||
"282\n",
|
||||
"283\n",
|
||||
"284\n",
|
||||
"285\n",
|
||||
"286\n",
|
||||
"287\n",
|
||||
"288\n",
|
||||
"289\n",
|
||||
"290\n",
|
||||
"291\n",
|
||||
"292\n",
|
||||
"293\n",
|
||||
"294\n",
|
||||
"295\n",
|
||||
"296\n",
|
||||
"297\n",
|
||||
"298\n",
|
||||
"299\n",
|
||||
"300\n",
|
||||
"301\n",
|
||||
"302\n",
|
||||
"303\n",
|
||||
"304\n",
|
||||
"305\n",
|
||||
"306\n",
|
||||
"307\n",
|
||||
"308\n",
|
||||
"309\n",
|
||||
"310\n",
|
||||
"311\n",
|
||||
"312\n",
|
||||
"313\n",
|
||||
"314\n",
|
||||
"315\n",
|
||||
"316\n",
|
||||
"317\n",
|
||||
"318\n",
|
||||
"319\n",
|
||||
"320\n",
|
||||
"321\n",
|
||||
"322\n",
|
||||
"323\n",
|
||||
"324\n",
|
||||
"325\n",
|
||||
"326\n",
|
||||
"327\n",
|
||||
"328\n",
|
||||
"329\n",
|
||||
"330\n",
|
||||
"331\n",
|
||||
"332\n",
|
||||
"333\n",
|
||||
"334\n",
|
||||
"335\n",
|
||||
"336\n",
|
||||
"337\n",
|
||||
"338\n",
|
||||
"339\n",
|
||||
"340\n",
|
||||
"341\n",
|
||||
"342\n",
|
||||
"343\n",
|
||||
"344\n",
|
||||
"345\n",
|
||||
"346\n",
|
||||
"347\n",
|
||||
"348\n",
|
||||
"349\n",
|
||||
"350\n",
|
||||
"351\n",
|
||||
"352\n",
|
||||
"353\n",
|
||||
"354\n",
|
||||
"355\n",
|
||||
"356\n",
|
||||
"357\n",
|
||||
"358\n",
|
||||
"359\n",
|
||||
"360\n",
|
||||
"361\n",
|
||||
"362\n",
|
||||
"363\n",
|
||||
"364\n",
|
||||
"365\n",
|
||||
"366\n",
|
||||
"367\n",
|
||||
"368\n",
|
||||
"369\n",
|
||||
"370\n",
|
||||
"371\n",
|
||||
"372\n",
|
||||
"373\n",
|
||||
"374\n"
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [],
|
||||
"metadata": {
|
||||
"id": "jVEQTWoW6l_9"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
}
|
||||
]
|
||||
}
|
||||
BIN
model_classifier.obj
Normal file
BIN
model_classifier.obj
Normal file
Binary file not shown.
177
model_inference.py
Normal file
177
model_inference.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import pickle
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import random
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from tqdm.notebook import tqdm
|
||||
import matplotlib.pyplot as plt
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision.io import read_image
|
||||
from torchvision import transforms
|
||||
import cv2
|
||||
|
||||
if './yolov5/' not in sys.path:
|
||||
sys.path.append('./yolov5/')
|
||||
|
||||
from models.common import DetectMultiBackend
|
||||
from utils.augmentations import letterbox
|
||||
from utils.general import scale_coords, non_max_suppression, check_img_size
|
||||
from utils.dataloaders import LoadImages
|
||||
from utils.plots import Annotator, colors, save_one_box
|
||||
|
||||
class EmbeddingModel(nn.Module):
|
||||
def __init__(self, emb_dim=128):
|
||||
super(EmbeddingModel, self).__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(3, 16, 3),
|
||||
nn.BatchNorm2d(16),
|
||||
nn.PReLU(),
|
||||
nn.MaxPool2d(2),
|
||||
|
||||
nn.Conv2d(16, 32, 3),
|
||||
nn.BatchNorm2d(32),
|
||||
nn.PReLU(32),
|
||||
nn.MaxPool2d(2),
|
||||
|
||||
nn.Conv2d(32, 64, 3),
|
||||
nn.PReLU(),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.MaxPool2d(2)
|
||||
)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(64*6*6, 256),
|
||||
nn.PReLU(),
|
||||
nn.Linear(256, emb_dim)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = x.view(-1, 64*6*6)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
class SquarePad:
|
||||
def __call__(self, image):
|
||||
_, w, h = image.size()
|
||||
max_wh = max(w, h)
|
||||
hp = int((max_wh - w) / 2)
|
||||
vp = int((max_wh - h) / 2)
|
||||
padding = (vp, hp, vp, hp)
|
||||
return transforms.functional.pad(image, padding, 0, 'constant')
|
||||
|
||||
class Normalize01:
|
||||
def __call__(self, image):
|
||||
image -= image.min()
|
||||
image /= image.max()
|
||||
return image
|
||||
|
||||
def prepare_for_embedding(image):
|
||||
image = torch.tensor(image).permute((2, 0, 1)).float().flip(0)
|
||||
transform = transforms.Compose([
|
||||
SquarePad(),
|
||||
transforms.Resize((64, 64)),
|
||||
Normalize01()
|
||||
])
|
||||
image = transform(image)
|
||||
return image.unsqueeze(0)
|
||||
|
||||
|
||||
if torch.cuda.is_available():
|
||||
print('Using GPU.')
|
||||
device = 'cuda'
|
||||
else:
|
||||
print("CUDA not detected, using CPU.")
|
||||
device = 'cpu'
|
||||
|
||||
model_embedding = EmbeddingModel()
|
||||
model_embedding.load_state_dict(torch.load('./embedding-output/model_embedding.pt'))
|
||||
model_embedding.to(device)
|
||||
model_embedding.eval()
|
||||
|
||||
with open('./model_classifier.obj','rb') as file:
|
||||
model_classifier = pickle.load(file)
|
||||
|
||||
classes = model_classifier.__getstate__()['classes_']
|
||||
|
||||
video = Path('./test_videos_2022/2022-NLS-5-NLS_05_2022_Heli_UHD_01-000140-000155-Karussell.mp4')
|
||||
|
||||
reader = cv2.VideoCapture(str(video))
|
||||
fps = reader.get(cv2.CAP_PROP_FPS)
|
||||
w = int(reader.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
h = int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
reader.release()
|
||||
|
||||
weights_path = Path('./yolov5/best.pt')
|
||||
model = DetectMultiBackend(weights_path, device=torch.device(device))
|
||||
|
||||
imgsz = check_img_size((w, h), s=model.stride)
|
||||
dataset = LoadImages(video, img_size=imgsz, stride=model.stride, auto=model.pt)
|
||||
|
||||
save_dir = Path('./detection-output/')
|
||||
os.makedirs(save_dir)
|
||||
|
||||
writer = cv2.VideoWriter(str(save_dir / 'res.mp4'), cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||
|
||||
for frame_n, (path, im, im0s, vid_cap, s) in enumerate(dataset):
|
||||
print(frame_n)
|
||||
im = torch.from_numpy(im).to(device)
|
||||
im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
|
||||
im /= 255 # 0 - 255 to 0.0 - 1.0
|
||||
im = im[None]
|
||||
pred = model(im)
|
||||
pred = non_max_suppression(pred, conf_thres = 0.5, max_det = 100)
|
||||
for i, det in enumerate(pred):
|
||||
p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
|
||||
imc = im0.copy()
|
||||
annotator = Annotator(imc, line_width=3, example=str(model.names), pil = True, font_size=20 )
|
||||
|
||||
if len(det) == 0:
|
||||
continue
|
||||
det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()
|
||||
|
||||
|
||||
for *xyxy, conf, cls in reversed(det):
|
||||
crop = save_one_box(xyxy, im0, file=save_dir / 'crops' / f'frame{frame_n}_{i}.jpg', BGR=True)
|
||||
image = prepare_for_embedding(crop).to(device)
|
||||
embedding = model_embedding(image).cpu().detach().numpy()
|
||||
|
||||
probabilities = model_classifier.predict_proba(embedding)[0]
|
||||
best = np.argmax(probabilities)
|
||||
|
||||
annotator.text([xyxy[0] -20, xyxy[1] - 20], classes[best])
|
||||
# print(classes[best])
|
||||
# print()
|
||||
|
||||
im0 = annotator.result().copy()
|
||||
writer.write(im0)
|
||||
|
||||
writer.release()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
gdown
|
||||
seaborn
|
||||
torch
|
||||
torchvision
|
||||
sklearn
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib as mpl
|
||||
import os
|
||||
from pathlib import Path
|
||||
import json
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
@@ -176,6 +177,8 @@ train_loss = []
|
||||
test_loss = []
|
||||
epoch_all_loss = []
|
||||
|
||||
print('Training.')
|
||||
|
||||
for epoch in range(epochs):
|
||||
|
||||
train_one_epoch_loss = []
|
||||
@@ -219,14 +222,27 @@ for epoch in range(epochs):
|
||||
|
||||
print(f"Epoch: {epoch+1}/{epochs} - Training loss: {np.mean(train_one_epoch_loss):.4f} - Test loss: {np.mean(test_one_epoch_loss):.4f}")
|
||||
|
||||
|
||||
output_dir = Path('./embedding-output')
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
torch.save(model.state_dict(), output_dir/'model_embedding.pt')
|
||||
|
||||
train_loss = np.array(train_loss)
|
||||
test_loss = np.array(test_loss)
|
||||
q = 10
|
||||
plt.plot(np.convolve(train_loss, np.ones(len(train_loss) - len(test_loss) + q)/(len(train_loss) - len(test_loss) + q), mode = 'valid'), legend = 'Train loss')
|
||||
plt.plot(np.convolve(test_loss, np.ones(q)/q, mode = 'valid'), legend = 'Test loss')
|
||||
|
||||
loss = {}
|
||||
loss["test_loss"] = list(test_loss.astype(float))
|
||||
loss["train_loss"] = list(train_loss.astype(float))
|
||||
with open(str(output_dir / 'loss.json'), 'w') as f:
|
||||
json.dump(loss, f, indent = 4)
|
||||
|
||||
q = 10 # plotting parameter
|
||||
plt.plot(np.convolve(train_loss, np.ones(len(train_loss) - len(test_loss) + q)/(len(train_loss) - len(test_loss) + q), mode = 'valid'), label = 'Train loss')
|
||||
plt.plot(np.convolve(test_loss, np.ones(q)/q, mode = 'valid'), label = 'Test loss')
|
||||
plt.legend()
|
||||
plt.title("Epoch loss")
|
||||
plt.savefig("embedding_loss.pdf")
|
||||
plt.savefig(output_dir/"embedding_loss.pdf")
|
||||
|
||||
if config["visualize"]:
|
||||
|
||||
@@ -261,5 +277,6 @@ if config["visualize"]:
|
||||
for i in range(len(classes)):
|
||||
legend.legendHandles[i]._sizes = [30]
|
||||
|
||||
plt.savefig('visualized_simple.pdf')
|
||||
plt.axis('off')
|
||||
plt.savefig(output_dir/'visualized_simple.pdf')
|
||||
|
||||
|
||||
BIN
yolov5.zip
Normal file
BIN
yolov5.zip
Normal file
Binary file not shown.
Reference in New Issue
Block a user