Compare commits
5 Commits
6a8c963dd9
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| bae1b00bdc | |||
| c6794e7f03 | |||
| 9b469860fa | |||
| c81d9106d5 | |||
| 68812a209b |
BIN
.model_inference.py.swp
Normal file
BIN
.model_inference.py.swp
Normal file
Binary file not shown.
@@ -2,5 +2,5 @@
|
|||||||
"n_neighbors": 10,
|
"n_neighbors": 10,
|
||||||
"classes" : ["black_red", "green_orange", "yellow_grey"],
|
"classes" : ["black_red", "green_orange", "yellow_grey"],
|
||||||
"path_to_dataset": "./triplet_dataset",
|
"path_to_dataset": "./triplet_dataset",
|
||||||
"embedding_model": "./model_embedding.pt"
|
"embedding_model": "./embedding-output/model_embedding.pt"
|
||||||
}
|
}
|
||||||
|
|||||||
4
download-test-videos
Executable file
4
download-test-videos
Executable file
@@ -0,0 +1,4 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
gdown 1Mm24z7fe1fkbcTt05IQpLlNdSALJQFRc
|
||||||
|
unzip -q test_videos_2022.zip
|
||||||
|
rm test_videos_2022.zip
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"epochs": 4,
|
"epochs": 1,
|
||||||
"embedding_dims": 128,
|
"embedding_dims": 128,
|
||||||
"batch_size": 32,
|
"batch_size": 32,
|
||||||
"classes" : ["black_red", "green_orange", "yellow_grey"],
|
"classes" : ["black_red", "green_orange", "yellow_grey"],
|
||||||
|
|||||||
611
example.ipynb
Normal file
611
example.ipynb
Normal file
@@ -0,0 +1,611 @@
|
|||||||
|
{
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 0,
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"provenance": [],
|
||||||
|
"collapsed_sections": []
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"name": "python3",
|
||||||
|
"display_name": "Python 3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"name": "python"
|
||||||
|
},
|
||||||
|
"accelerator": "GPU",
|
||||||
|
"gpuClass": "standard"
|
||||||
|
},
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/"
|
||||||
|
},
|
||||||
|
"id": "DbeYatBLu8pf",
|
||||||
|
"outputId": "4ed57e67-ae61-4662-e0cc-faed1c4062e4"
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"output_type": "stream",
|
||||||
|
"name": "stdout",
|
||||||
|
"text": [
|
||||||
|
"Cloning into 'triplet-loss-cars'...\n",
|
||||||
|
"remote: Enumerating objects: 17560, done.\u001b[K\n",
|
||||||
|
"remote: Counting objects: 100% (17560/17560), done.\u001b[K\n",
|
||||||
|
"remote: Compressing objects: 100% (9229/9229), done.\u001b[K\n",
|
||||||
|
"remote: Total 17560 (delta 8336), reused 17531 (delta 8326)\u001b[K\n",
|
||||||
|
"Receiving objects: 100% (17560/17560), 773.26 MiB | 16.83 MiB/s, done.\n",
|
||||||
|
"Resolving deltas: 100% (8336/8336), done.\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"!git clone https://git.drivecast.tech/vd/triplet-loss-cars.git"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"%cd triplet-loss-cars/"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/"
|
||||||
|
},
|
||||||
|
"id": "264RI2RHu_u4",
|
||||||
|
"outputId": "5cea8550-3c8f-453f-ee09-de205e117769"
|
||||||
|
},
|
||||||
|
"execution_count": 2,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"output_type": "stream",
|
||||||
|
"name": "stdout",
|
||||||
|
"text": [
|
||||||
|
"/content/triplet-loss-cars\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"!./download-dataset"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/"
|
||||||
|
},
|
||||||
|
"id": "lboCMkqrvQtO",
|
||||||
|
"outputId": "015dad3e-fda4-4daf-a86c-54bf37253fa4"
|
||||||
|
},
|
||||||
|
"execution_count": 3,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"output_type": "stream",
|
||||||
|
"name": "stdout",
|
||||||
|
"text": [
|
||||||
|
"Downloading...\n",
|
||||||
|
"From: https://drive.google.com/uc?id=1rP7GHDqx6BKTGTh9I6ecEmRgn5-HG1N0\n",
|
||||||
|
"To: /content/triplet-loss-cars/triplet_dataset.zip\n",
|
||||||
|
"100% 27.3M/27.3M [00:01<00:00, 20.6MB/s]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [],
|
||||||
|
"metadata": {
|
||||||
|
"id": "Jx07-q67wGfD"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"!python3 train-embedding.py"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/"
|
||||||
|
},
|
||||||
|
"id": "iat8B1qNv96T",
|
||||||
|
"outputId": "712601bf-9a2a-44d9-8d27-6dbc82a941f1"
|
||||||
|
},
|
||||||
|
"execution_count": 7,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"output_type": "stream",
|
||||||
|
"name": "stdout",
|
||||||
|
"text": [
|
||||||
|
"Using GPU.\n",
|
||||||
|
"Training.\n",
|
||||||
|
"Epoch: 1/10 - Training loss: 0.2827 - Test loss: 0.1240\n",
|
||||||
|
"Epoch: 2/10 - Training loss: 0.1167 - Test loss: 0.1096\n",
|
||||||
|
"Epoch: 3/10 - Training loss: 0.0887 - Test loss: 0.0779\n",
|
||||||
|
"Epoch: 4/10 - Training loss: 0.0789 - Test loss: 0.0882\n",
|
||||||
|
"Epoch: 5/10 - Training loss: 0.0718 - Test loss: 0.0383\n",
|
||||||
|
"Epoch: 6/10 - Training loss: 0.0643 - Test loss: 0.0566\n",
|
||||||
|
"^C\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"!python3 train-classifier.py"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/"
|
||||||
|
},
|
||||||
|
"id": "XhCmibyKwC-U",
|
||||||
|
"outputId": "e4b28c5a-d878-40d5-aaaf-f256e1ab2fa7"
|
||||||
|
},
|
||||||
|
"execution_count": 8,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"output_type": "stream",
|
||||||
|
"name": "stdout",
|
||||||
|
"text": [
|
||||||
|
"Using GPU.\n",
|
||||||
|
"Score: 0.9918\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"!./download-test-videos"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/"
|
||||||
|
},
|
||||||
|
"id": "9PRr_7pV4qwE",
|
||||||
|
"outputId": "a71579b1-992b-41db-a275-79e5e55b5715"
|
||||||
|
},
|
||||||
|
"execution_count": 10,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"output_type": "stream",
|
||||||
|
"name": "stdout",
|
||||||
|
"text": [
|
||||||
|
"Downloading...\n",
|
||||||
|
"From: https://drive.google.com/uc?id=1Mm24z7fe1fkbcTt05IQpLlNdSALJQFRc\n",
|
||||||
|
"To: /content/triplet-loss-cars/test_videos_2022.zip\n",
|
||||||
|
"100% 212M/212M [00:03<00:00, 63.2MB/s]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"!unzip -q ./yolov5.zip"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "wZjPyaNo6XAa"
|
||||||
|
},
|
||||||
|
"execution_count": 11,
|
||||||
|
"outputs": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [
|
||||||
|
"!python3 model_inference.py"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/"
|
||||||
|
},
|
||||||
|
"id": "EW8pO4f86fBj",
|
||||||
|
"outputId": "15c057b3-c4df-43cb-9a6a-4201404b226e"
|
||||||
|
},
|
||||||
|
"execution_count": 12,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"output_type": "stream",
|
||||||
|
"name": "stdout",
|
||||||
|
"text": [
|
||||||
|
"Using GPU.\n",
|
||||||
|
"Fusing layers... \n",
|
||||||
|
"custom_YOLOv5l summary: 290 layers, 20852934 parameters, 0 gradients\n",
|
||||||
|
"WARNING: --img-size [1920, 1080] must be multiple of max stride 32, updating to [1920, 1088]\n",
|
||||||
|
"0\n",
|
||||||
|
"Downloading https://ultralytics.com/assets/Arial.ttf to /root/.config/Ultralytics/Arial.ttf...\n",
|
||||||
|
"1\n",
|
||||||
|
"2\n",
|
||||||
|
"3\n",
|
||||||
|
"4\n",
|
||||||
|
"5\n",
|
||||||
|
"6\n",
|
||||||
|
"7\n",
|
||||||
|
"8\n",
|
||||||
|
"9\n",
|
||||||
|
"10\n",
|
||||||
|
"11\n",
|
||||||
|
"12\n",
|
||||||
|
"13\n",
|
||||||
|
"14\n",
|
||||||
|
"15\n",
|
||||||
|
"16\n",
|
||||||
|
"17\n",
|
||||||
|
"18\n",
|
||||||
|
"19\n",
|
||||||
|
"20\n",
|
||||||
|
"21\n",
|
||||||
|
"22\n",
|
||||||
|
"23\n",
|
||||||
|
"24\n",
|
||||||
|
"25\n",
|
||||||
|
"26\n",
|
||||||
|
"27\n",
|
||||||
|
"28\n",
|
||||||
|
"29\n",
|
||||||
|
"30\n",
|
||||||
|
"31\n",
|
||||||
|
"32\n",
|
||||||
|
"33\n",
|
||||||
|
"34\n",
|
||||||
|
"35\n",
|
||||||
|
"36\n",
|
||||||
|
"37\n",
|
||||||
|
"38\n",
|
||||||
|
"39\n",
|
||||||
|
"40\n",
|
||||||
|
"41\n",
|
||||||
|
"42\n",
|
||||||
|
"43\n",
|
||||||
|
"44\n",
|
||||||
|
"45\n",
|
||||||
|
"46\n",
|
||||||
|
"47\n",
|
||||||
|
"48\n",
|
||||||
|
"49\n",
|
||||||
|
"50\n",
|
||||||
|
"51\n",
|
||||||
|
"52\n",
|
||||||
|
"53\n",
|
||||||
|
"54\n",
|
||||||
|
"55\n",
|
||||||
|
"56\n",
|
||||||
|
"57\n",
|
||||||
|
"58\n",
|
||||||
|
"59\n",
|
||||||
|
"60\n",
|
||||||
|
"61\n",
|
||||||
|
"62\n",
|
||||||
|
"63\n",
|
||||||
|
"64\n",
|
||||||
|
"65\n",
|
||||||
|
"66\n",
|
||||||
|
"67\n",
|
||||||
|
"68\n",
|
||||||
|
"69\n",
|
||||||
|
"70\n",
|
||||||
|
"71\n",
|
||||||
|
"72\n",
|
||||||
|
"73\n",
|
||||||
|
"74\n",
|
||||||
|
"75\n",
|
||||||
|
"76\n",
|
||||||
|
"77\n",
|
||||||
|
"78\n",
|
||||||
|
"79\n",
|
||||||
|
"80\n",
|
||||||
|
"81\n",
|
||||||
|
"82\n",
|
||||||
|
"83\n",
|
||||||
|
"84\n",
|
||||||
|
"85\n",
|
||||||
|
"86\n",
|
||||||
|
"87\n",
|
||||||
|
"88\n",
|
||||||
|
"89\n",
|
||||||
|
"90\n",
|
||||||
|
"91\n",
|
||||||
|
"92\n",
|
||||||
|
"93\n",
|
||||||
|
"94\n",
|
||||||
|
"95\n",
|
||||||
|
"96\n",
|
||||||
|
"97\n",
|
||||||
|
"98\n",
|
||||||
|
"99\n",
|
||||||
|
"100\n",
|
||||||
|
"101\n",
|
||||||
|
"102\n",
|
||||||
|
"103\n",
|
||||||
|
"104\n",
|
||||||
|
"105\n",
|
||||||
|
"106\n",
|
||||||
|
"107\n",
|
||||||
|
"108\n",
|
||||||
|
"109\n",
|
||||||
|
"110\n",
|
||||||
|
"111\n",
|
||||||
|
"112\n",
|
||||||
|
"113\n",
|
||||||
|
"114\n",
|
||||||
|
"115\n",
|
||||||
|
"116\n",
|
||||||
|
"117\n",
|
||||||
|
"118\n",
|
||||||
|
"119\n",
|
||||||
|
"120\n",
|
||||||
|
"121\n",
|
||||||
|
"122\n",
|
||||||
|
"123\n",
|
||||||
|
"124\n",
|
||||||
|
"125\n",
|
||||||
|
"126\n",
|
||||||
|
"127\n",
|
||||||
|
"128\n",
|
||||||
|
"129\n",
|
||||||
|
"130\n",
|
||||||
|
"131\n",
|
||||||
|
"132\n",
|
||||||
|
"133\n",
|
||||||
|
"134\n",
|
||||||
|
"135\n",
|
||||||
|
"136\n",
|
||||||
|
"137\n",
|
||||||
|
"138\n",
|
||||||
|
"139\n",
|
||||||
|
"140\n",
|
||||||
|
"141\n",
|
||||||
|
"142\n",
|
||||||
|
"143\n",
|
||||||
|
"144\n",
|
||||||
|
"145\n",
|
||||||
|
"146\n",
|
||||||
|
"147\n",
|
||||||
|
"148\n",
|
||||||
|
"149\n",
|
||||||
|
"150\n",
|
||||||
|
"151\n",
|
||||||
|
"152\n",
|
||||||
|
"153\n",
|
||||||
|
"154\n",
|
||||||
|
"155\n",
|
||||||
|
"156\n",
|
||||||
|
"157\n",
|
||||||
|
"158\n",
|
||||||
|
"159\n",
|
||||||
|
"160\n",
|
||||||
|
"161\n",
|
||||||
|
"162\n",
|
||||||
|
"163\n",
|
||||||
|
"164\n",
|
||||||
|
"165\n",
|
||||||
|
"166\n",
|
||||||
|
"167\n",
|
||||||
|
"168\n",
|
||||||
|
"169\n",
|
||||||
|
"170\n",
|
||||||
|
"171\n",
|
||||||
|
"172\n",
|
||||||
|
"173\n",
|
||||||
|
"174\n",
|
||||||
|
"175\n",
|
||||||
|
"176\n",
|
||||||
|
"177\n",
|
||||||
|
"178\n",
|
||||||
|
"179\n",
|
||||||
|
"180\n",
|
||||||
|
"181\n",
|
||||||
|
"182\n",
|
||||||
|
"183\n",
|
||||||
|
"184\n",
|
||||||
|
"185\n",
|
||||||
|
"186\n",
|
||||||
|
"187\n",
|
||||||
|
"188\n",
|
||||||
|
"189\n",
|
||||||
|
"190\n",
|
||||||
|
"191\n",
|
||||||
|
"192\n",
|
||||||
|
"193\n",
|
||||||
|
"194\n",
|
||||||
|
"195\n",
|
||||||
|
"196\n",
|
||||||
|
"197\n",
|
||||||
|
"198\n",
|
||||||
|
"199\n",
|
||||||
|
"200\n",
|
||||||
|
"201\n",
|
||||||
|
"202\n",
|
||||||
|
"203\n",
|
||||||
|
"204\n",
|
||||||
|
"205\n",
|
||||||
|
"206\n",
|
||||||
|
"207\n",
|
||||||
|
"208\n",
|
||||||
|
"209\n",
|
||||||
|
"210\n",
|
||||||
|
"211\n",
|
||||||
|
"212\n",
|
||||||
|
"213\n",
|
||||||
|
"214\n",
|
||||||
|
"215\n",
|
||||||
|
"216\n",
|
||||||
|
"217\n",
|
||||||
|
"218\n",
|
||||||
|
"219\n",
|
||||||
|
"220\n",
|
||||||
|
"221\n",
|
||||||
|
"222\n",
|
||||||
|
"223\n",
|
||||||
|
"224\n",
|
||||||
|
"225\n",
|
||||||
|
"226\n",
|
||||||
|
"227\n",
|
||||||
|
"228\n",
|
||||||
|
"229\n",
|
||||||
|
"230\n",
|
||||||
|
"231\n",
|
||||||
|
"232\n",
|
||||||
|
"233\n",
|
||||||
|
"234\n",
|
||||||
|
"235\n",
|
||||||
|
"236\n",
|
||||||
|
"237\n",
|
||||||
|
"238\n",
|
||||||
|
"239\n",
|
||||||
|
"240\n",
|
||||||
|
"241\n",
|
||||||
|
"242\n",
|
||||||
|
"243\n",
|
||||||
|
"244\n",
|
||||||
|
"245\n",
|
||||||
|
"246\n",
|
||||||
|
"247\n",
|
||||||
|
"248\n",
|
||||||
|
"249\n",
|
||||||
|
"250\n",
|
||||||
|
"251\n",
|
||||||
|
"252\n",
|
||||||
|
"253\n",
|
||||||
|
"254\n",
|
||||||
|
"255\n",
|
||||||
|
"256\n",
|
||||||
|
"257\n",
|
||||||
|
"258\n",
|
||||||
|
"259\n",
|
||||||
|
"260\n",
|
||||||
|
"261\n",
|
||||||
|
"262\n",
|
||||||
|
"263\n",
|
||||||
|
"264\n",
|
||||||
|
"265\n",
|
||||||
|
"266\n",
|
||||||
|
"267\n",
|
||||||
|
"268\n",
|
||||||
|
"269\n",
|
||||||
|
"270\n",
|
||||||
|
"271\n",
|
||||||
|
"272\n",
|
||||||
|
"273\n",
|
||||||
|
"274\n",
|
||||||
|
"275\n",
|
||||||
|
"276\n",
|
||||||
|
"277\n",
|
||||||
|
"278\n",
|
||||||
|
"279\n",
|
||||||
|
"280\n",
|
||||||
|
"281\n",
|
||||||
|
"282\n",
|
||||||
|
"283\n",
|
||||||
|
"284\n",
|
||||||
|
"285\n",
|
||||||
|
"286\n",
|
||||||
|
"287\n",
|
||||||
|
"288\n",
|
||||||
|
"289\n",
|
||||||
|
"290\n",
|
||||||
|
"291\n",
|
||||||
|
"292\n",
|
||||||
|
"293\n",
|
||||||
|
"294\n",
|
||||||
|
"295\n",
|
||||||
|
"296\n",
|
||||||
|
"297\n",
|
||||||
|
"298\n",
|
||||||
|
"299\n",
|
||||||
|
"300\n",
|
||||||
|
"301\n",
|
||||||
|
"302\n",
|
||||||
|
"303\n",
|
||||||
|
"304\n",
|
||||||
|
"305\n",
|
||||||
|
"306\n",
|
||||||
|
"307\n",
|
||||||
|
"308\n",
|
||||||
|
"309\n",
|
||||||
|
"310\n",
|
||||||
|
"311\n",
|
||||||
|
"312\n",
|
||||||
|
"313\n",
|
||||||
|
"314\n",
|
||||||
|
"315\n",
|
||||||
|
"316\n",
|
||||||
|
"317\n",
|
||||||
|
"318\n",
|
||||||
|
"319\n",
|
||||||
|
"320\n",
|
||||||
|
"321\n",
|
||||||
|
"322\n",
|
||||||
|
"323\n",
|
||||||
|
"324\n",
|
||||||
|
"325\n",
|
||||||
|
"326\n",
|
||||||
|
"327\n",
|
||||||
|
"328\n",
|
||||||
|
"329\n",
|
||||||
|
"330\n",
|
||||||
|
"331\n",
|
||||||
|
"332\n",
|
||||||
|
"333\n",
|
||||||
|
"334\n",
|
||||||
|
"335\n",
|
||||||
|
"336\n",
|
||||||
|
"337\n",
|
||||||
|
"338\n",
|
||||||
|
"339\n",
|
||||||
|
"340\n",
|
||||||
|
"341\n",
|
||||||
|
"342\n",
|
||||||
|
"343\n",
|
||||||
|
"344\n",
|
||||||
|
"345\n",
|
||||||
|
"346\n",
|
||||||
|
"347\n",
|
||||||
|
"348\n",
|
||||||
|
"349\n",
|
||||||
|
"350\n",
|
||||||
|
"351\n",
|
||||||
|
"352\n",
|
||||||
|
"353\n",
|
||||||
|
"354\n",
|
||||||
|
"355\n",
|
||||||
|
"356\n",
|
||||||
|
"357\n",
|
||||||
|
"358\n",
|
||||||
|
"359\n",
|
||||||
|
"360\n",
|
||||||
|
"361\n",
|
||||||
|
"362\n",
|
||||||
|
"363\n",
|
||||||
|
"364\n",
|
||||||
|
"365\n",
|
||||||
|
"366\n",
|
||||||
|
"367\n",
|
||||||
|
"368\n",
|
||||||
|
"369\n",
|
||||||
|
"370\n",
|
||||||
|
"371\n",
|
||||||
|
"372\n",
|
||||||
|
"373\n",
|
||||||
|
"374\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"source": [],
|
||||||
|
"metadata": {
|
||||||
|
"id": "jVEQTWoW6l_9"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
BIN
model_classifier.obj
Normal file
BIN
model_classifier.obj
Normal file
Binary file not shown.
177
model_inference.py
Normal file
177
model_inference.py
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
import pickle
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.backends.cudnn as cudnn
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from tqdm.notebook import tqdm
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from torchvision.io import read_image
|
||||||
|
from torchvision import transforms
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
if './yolov5/' not in sys.path:
|
||||||
|
sys.path.append('./yolov5/')
|
||||||
|
|
||||||
|
from models.common import DetectMultiBackend
|
||||||
|
from utils.augmentations import letterbox
|
||||||
|
from utils.general import scale_coords, non_max_suppression, check_img_size
|
||||||
|
from utils.dataloaders import LoadImages
|
||||||
|
from utils.plots import Annotator, colors, save_one_box
|
||||||
|
|
||||||
|
class EmbeddingModel(nn.Module):
|
||||||
|
def __init__(self, emb_dim=128):
|
||||||
|
super(EmbeddingModel, self).__init__()
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
nn.Conv2d(3, 16, 3),
|
||||||
|
nn.BatchNorm2d(16),
|
||||||
|
nn.PReLU(),
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
|
||||||
|
nn.Conv2d(16, 32, 3),
|
||||||
|
nn.BatchNorm2d(32),
|
||||||
|
nn.PReLU(32),
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
|
||||||
|
nn.Conv2d(32, 64, 3),
|
||||||
|
nn.PReLU(),
|
||||||
|
nn.BatchNorm2d(64),
|
||||||
|
nn.MaxPool2d(2)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.fc = nn.Sequential(
|
||||||
|
nn.Linear(64*6*6, 256),
|
||||||
|
nn.PReLU(),
|
||||||
|
nn.Linear(256, emb_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = x.view(-1, 64*6*6)
|
||||||
|
x = self.fc(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class SquarePad:
|
||||||
|
def __call__(self, image):
|
||||||
|
_, w, h = image.size()
|
||||||
|
max_wh = max(w, h)
|
||||||
|
hp = int((max_wh - w) / 2)
|
||||||
|
vp = int((max_wh - h) / 2)
|
||||||
|
padding = (vp, hp, vp, hp)
|
||||||
|
return transforms.functional.pad(image, padding, 0, 'constant')
|
||||||
|
|
||||||
|
class Normalize01:
|
||||||
|
def __call__(self, image):
|
||||||
|
image -= image.min()
|
||||||
|
image /= image.max()
|
||||||
|
return image
|
||||||
|
|
||||||
|
def prepare_for_embedding(image):
|
||||||
|
image = torch.tensor(image).permute((2, 0, 1)).float().flip(0)
|
||||||
|
transform = transforms.Compose([
|
||||||
|
SquarePad(),
|
||||||
|
transforms.Resize((64, 64)),
|
||||||
|
Normalize01()
|
||||||
|
])
|
||||||
|
image = transform(image)
|
||||||
|
return image.unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
print('Using GPU.')
|
||||||
|
device = 'cuda'
|
||||||
|
else:
|
||||||
|
print("CUDA not detected, using CPU.")
|
||||||
|
device = 'cpu'
|
||||||
|
|
||||||
|
model_embedding = EmbeddingModel()
|
||||||
|
model_embedding.load_state_dict(torch.load('./embedding-output/model_embedding.pt'))
|
||||||
|
model_embedding.to(device)
|
||||||
|
model_embedding.eval()
|
||||||
|
|
||||||
|
with open('./model_classifier.obj','rb') as file:
|
||||||
|
model_classifier = pickle.load(file)
|
||||||
|
|
||||||
|
classes = model_classifier.__getstate__()['classes_']
|
||||||
|
|
||||||
|
video = Path('./test_videos_2022/2022-NLS-5-NLS_05_2022_Heli_UHD_01-000140-000155-Karussell.mp4')
|
||||||
|
|
||||||
|
reader = cv2.VideoCapture(str(video))
|
||||||
|
fps = reader.get(cv2.CAP_PROP_FPS)
|
||||||
|
w = int(reader.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||||
|
h = int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||||
|
reader.release()
|
||||||
|
|
||||||
|
weights_path = Path('./yolov5/best.pt')
|
||||||
|
model = DetectMultiBackend(weights_path, device=torch.device(device))
|
||||||
|
|
||||||
|
imgsz = check_img_size((w, h), s=model.stride)
|
||||||
|
dataset = LoadImages(video, img_size=imgsz, stride=model.stride, auto=model.pt)
|
||||||
|
|
||||||
|
save_dir = Path('./detection-output/')
|
||||||
|
os.makedirs(save_dir)
|
||||||
|
|
||||||
|
writer = cv2.VideoWriter(str(save_dir / 'res.mp4'), cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||||
|
|
||||||
|
for frame_n, (path, im, im0s, vid_cap, s) in enumerate(dataset):
|
||||||
|
print(frame_n)
|
||||||
|
im = torch.from_numpy(im).to(device)
|
||||||
|
im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
|
||||||
|
im /= 255 # 0 - 255 to 0.0 - 1.0
|
||||||
|
im = im[None]
|
||||||
|
pred = model(im)
|
||||||
|
pred = non_max_suppression(pred, conf_thres = 0.5, max_det = 100)
|
||||||
|
for i, det in enumerate(pred):
|
||||||
|
p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
|
||||||
|
imc = im0.copy()
|
||||||
|
annotator = Annotator(imc, line_width=3, example=str(model.names), pil = True, font_size=20 )
|
||||||
|
|
||||||
|
if len(det) == 0:
|
||||||
|
continue
|
||||||
|
det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()
|
||||||
|
|
||||||
|
|
||||||
|
for *xyxy, conf, cls in reversed(det):
|
||||||
|
crop = save_one_box(xyxy, im0, file=save_dir / 'crops' / f'frame{frame_n}_{i}.jpg', BGR=True)
|
||||||
|
image = prepare_for_embedding(crop).to(device)
|
||||||
|
embedding = model_embedding(image).cpu().detach().numpy()
|
||||||
|
|
||||||
|
probabilities = model_classifier.predict_proba(embedding)[0]
|
||||||
|
best = np.argmax(probabilities)
|
||||||
|
|
||||||
|
annotator.text([xyxy[0] -20, xyxy[1] - 20], classes[best])
|
||||||
|
# print(classes[best])
|
||||||
|
# print()
|
||||||
|
|
||||||
|
im0 = annotator.result().copy()
|
||||||
|
writer.write(im0)
|
||||||
|
|
||||||
|
writer.release()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
gdown
|
gdown
|
||||||
|
seaborn
|
||||||
torch
|
torch
|
||||||
torchvision
|
torchvision
|
||||||
sklearn
|
sklearn
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import torch.nn as nn
|
|||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import matplotlib as mpl
|
import matplotlib as mpl
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import json
|
import json
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
@@ -176,6 +177,8 @@ train_loss = []
|
|||||||
test_loss = []
|
test_loss = []
|
||||||
epoch_all_loss = []
|
epoch_all_loss = []
|
||||||
|
|
||||||
|
print('Training.')
|
||||||
|
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
|
|
||||||
train_one_epoch_loss = []
|
train_one_epoch_loss = []
|
||||||
@@ -219,14 +222,27 @@ for epoch in range(epochs):
|
|||||||
|
|
||||||
print(f"Epoch: {epoch+1}/{epochs} - Training loss: {np.mean(train_one_epoch_loss):.4f} - Test loss: {np.mean(test_one_epoch_loss):.4f}")
|
print(f"Epoch: {epoch+1}/{epochs} - Training loss: {np.mean(train_one_epoch_loss):.4f} - Test loss: {np.mean(test_one_epoch_loss):.4f}")
|
||||||
|
|
||||||
|
|
||||||
|
output_dir = Path('./embedding-output')
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
torch.save(model.state_dict(), output_dir/'model_embedding.pt')
|
||||||
|
|
||||||
train_loss = np.array(train_loss)
|
train_loss = np.array(train_loss)
|
||||||
test_loss = np.array(test_loss)
|
test_loss = np.array(test_loss)
|
||||||
q = 10
|
|
||||||
plt.plot(np.convolve(train_loss, np.ones(len(train_loss) - len(test_loss) + q)/(len(train_loss) - len(test_loss) + q), mode = 'valid'), legend = 'Train loss')
|
loss = {}
|
||||||
plt.plot(np.convolve(test_loss, np.ones(q)/q, mode = 'valid'), legend = 'Test loss')
|
loss["test_loss"] = list(test_loss.astype(float))
|
||||||
|
loss["train_loss"] = list(train_loss.astype(float))
|
||||||
|
with open(str(output_dir / 'loss.json'), 'w') as f:
|
||||||
|
json.dump(loss, f, indent = 4)
|
||||||
|
|
||||||
|
q = 10 # plotting parameter
|
||||||
|
plt.plot(np.convolve(train_loss, np.ones(len(train_loss) - len(test_loss) + q)/(len(train_loss) - len(test_loss) + q), mode = 'valid'), label = 'Train loss')
|
||||||
|
plt.plot(np.convolve(test_loss, np.ones(q)/q, mode = 'valid'), label = 'Test loss')
|
||||||
plt.legend()
|
plt.legend()
|
||||||
plt.title("Epoch loss")
|
plt.title("Epoch loss")
|
||||||
plt.savefig("embedding_loss.pdf")
|
plt.savefig(output_dir/"embedding_loss.pdf")
|
||||||
|
|
||||||
if config["visualize"]:
|
if config["visualize"]:
|
||||||
|
|
||||||
@@ -261,5 +277,6 @@ if config["visualize"]:
|
|||||||
for i in range(len(classes)):
|
for i in range(len(classes)):
|
||||||
legend.legendHandles[i]._sizes = [30]
|
legend.legendHandles[i]._sizes = [30]
|
||||||
|
|
||||||
plt.savefig('visualized_simple.pdf')
|
plt.axis('off')
|
||||||
|
plt.savefig(output_dir/'visualized_simple.pdf')
|
||||||
|
|
||||||
|
|||||||
BIN
yolov5.zip
Normal file
BIN
yolov5.zip
Normal file
Binary file not shown.
Reference in New Issue
Block a user