# Builds a kNN image classifier on top of embeddings from a pretrained CNN.
import torch
|
|
import random
|
|
import numpy as np
|
|
import pandas as pd
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from tqdm.notebook import tqdm
|
|
import matplotlib.pyplot as plt
|
|
from pathlib import Path
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from torchvision.io import read_image
|
|
from torchvision import transforms
|
|
from sklearn.neighbors import KNeighborsClassifier as kNN
|
|
import pickle
|
|
import json
|
|
|
|
# Load the classifier configuration (dataset path, class list, model path, ...).
with open('classifier_config.json', encoding='utf-8') as config_file:
    config = json.load(config_file)

path_to_dataset = Path(config["path_to_dataset"])

# An empty "classes" list in the config means: use every class directory
# found under the dataset root (sorted for a deterministic label order).
if not config["classes"]:
    classes = sorted(x.name for x in path_to_dataset.iterdir() if x.is_dir())
else:
    classes = config["classes"]
|
class SquarePad:
    """Zero-pad a CHW image tensor so it becomes exactly square.

    Padding is split as evenly as possible between the two sides of the
    shorter dimension; for an odd difference the extra pixel goes to the
    right/bottom. The original implementation truncated the half-padding
    on BOTH sides (``int((max_wh - s) / 2)``), so an odd size difference
    produced a non-square output; it also unpacked ``(C, H, W)`` into
    misleadingly swapped ``w, h`` names.
    """

    def __call__(self, image):
        # torchvision image tensors are channel-first: (C, H, W).
        _, h, w = image.size()
        max_wh = max(h, w)
        left = (max_wh - w) // 2
        top = (max_wh - h) // 2
        right = max_wh - w - left      # absorbs the odd leftover pixel
        bottom = max_wh - h - top
        # F.pad pads the last dims of a CHW tensor: (left, right, top, bottom).
        return torch.nn.functional.pad(
            image, (left, right, top, bottom), mode='constant', value=0
        )
|
class Normalize01:
    """Linearly rescale a tensor to the [0, 1] range.

    Operates on a copy (the original mutated the caller's tensor in place
    with ``-=``/``/=``) and guards against division by zero: a constant
    image maps to all zeros instead of NaNs.
    """

    def __call__(self, image):
        image = image - image.min()   # shift minimum to 0 (allocates a copy)
        peak = image.max()
        if peak > 0:                  # constant input -> zeros, not NaN
            image = image / peak
        return image
|
class CustomDataset(Dataset):
    """Triplet dataset: ``__getitem__`` yields ([anchor, positive, negative], label).

    Images are discovered as ``*.jpg`` files under ``path/<class_name>/`` for
    each requested class. Every image is square-padded, resized to ``size``
    and rescaled to [0, 1]; optional augmentations are applied on top.
    """

    def __init__(self, path, classes=None, augmentations=None,
                 target_transform=None, size=(64, 64)):
        """
        Args:
            path: root directory containing one sub-directory per class.
            classes: class names to load. If None, every sub-directory of
                ``path`` is used (sorted) — the original crashed on the
                default None by iterating over it.
            augmentations: optional transform applied to all three images.
            target_transform: kept for interface compatibility (unused here).
            size: (H, W) target size for ``transforms.Resize``.
        """
        self.path = path  # path to directories
        self.transform_aug = augmentations
        self.target_transform = target_transform
        self.size = size

        if classes is None:
            classes = sorted(p.name for p in Path(path).iterdir() if p.is_dir())
        self.classes = classes

        self.paths_to_images = []
        self.labels = []
        for c in self.classes:
            paths_to_class = list(Path(self.path, c).glob('*.jpg'))
            self.labels += [c] * len(paths_to_class)
            self.paths_to_images += paths_to_class
        self.labels = np.array(self.labels)

        # Pre-compute positive/negative index lists once per class; the
        # original rebuilt both with a full O(n) scan on every __getitem__.
        n = len(self.labels)
        self._positive = {}
        self._negative = {}
        for c in self.classes:
            self._positive[c] = [i for i in range(n) if self.labels[i] == c]
            self._negative[c] = [i for i in range(n) if self.labels[i] != c]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        anchor_label = self.labels[idx]

        # Anchor itself may be drawn as its own positive (original behavior).
        positive_idx = np.random.choice(self._positive[anchor_label])
        negative_idx = np.random.choice(self._negative[anchor_label])

        images = [
            read_image(str(self.paths_to_images[i])).float()
            for i in (idx, positive_idx, negative_idx)
        ]

        transform = transforms.Compose([
            SquarePad(),
            transforms.Resize(self.size),
            Normalize01(),
        ])
        images = [transform(image) for image in images]

        if self.transform_aug is not None:
            images = [self.transform_aug(image) for image in images]

        return images, anchor_label
|
# batch_size of 1 keeps the triplet structure simple: each loader step
# yields a single ((anchor, positive, negative), label) tuple.
batch_size = 1

# Dataset over the configured classes; shuffle only affects the order in
# which embeddings are collected below, not the classifier itself.
dataset = CustomDataset(path_to_dataset, classes=classes, size=(64, 64))

loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
class EmbeddingModel(nn.Module):
|
|
def __init__(self, emb_dim=128):
|
|
super(EmbeddingModel, self).__init__()
|
|
self.conv = nn.Sequential(
|
|
nn.Conv2d(3, 16, 3),
|
|
nn.BatchNorm2d(16),
|
|
nn.PReLU(),
|
|
nn.MaxPool2d(2),
|
|
|
|
nn.Conv2d(16, 32, 3),
|
|
nn.BatchNorm2d(32),
|
|
nn.PReLU(32),
|
|
nn.MaxPool2d(2),
|
|
|
|
nn.Conv2d(32, 64, 3),
|
|
nn.PReLU(),
|
|
nn.BatchNorm2d(64),
|
|
nn.MaxPool2d(2)
|
|
)
|
|
|
|
self.fc = nn.Sequential(
|
|
nn.Linear(64*6*6, 256),
|
|
nn.PReLU(),
|
|
nn.Linear(256, emb_dim)
|
|
)
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
x = x.view(-1, 64*6*6)
|
|
x = self.fc(x)
|
|
return x
|
|
|
|
|
|
# Select the compute device, preferring the GPU whenever CUDA is available.
use_gpu = torch.cuda.is_available()
device = 'cuda' if use_gpu else 'cpu'
print('Using GPU.' if use_gpu else "CUDA not detected, using CPU.")
|
# Build the embedding network and restore its trained weights.
model_embedding = EmbeddingModel()

# map_location lets a checkpoint saved on a GPU load on CPU-only machines;
# the original torch.load call raised in that situation.
model_embedding.load_state_dict(
    torch.load(config["embedding_model"], map_location=device)
)
model_embedding.to(device)
model_embedding.eval()
|
# Run every image through the embedding network and collect (embedding, label)
# pairs as the training set for the kNN classifier.
X = []
labels = []
#images = []
# no_grad: pure inference — skips autograd bookkeeping the original paid for
# and then worked around with .detach().
with torch.no_grad():
    for step, ((batch, _, _), label) in enumerate(loader):
        X += [*model_embedding(batch.to(device)).cpu().numpy()]
        labels += [*label]
        # Re-normalize the batch for visualization; only useful together
        # with the commented-out `images` collection above.
        for x in batch:
            x -= x.min()
            x /= x.max()
            #images += [transforms.functional.to_pil_image(x)]

X = np.array(X)
labels = np.array(labels)
|
|
# Fit a k-nearest-neighbours classifier on the collected embeddings.
model = kNN(n_neighbors=config["n_neighbors"])
model.fit(X, labels)

# Persist the fitted classifier. The original passed the path as a bytes
# literal (b'./model_classifier.obj'), which open() tolerates but was
# almost certainly accidental; the path itself is unchanged.
with open('./model_classifier.obj', 'wb') as file:
    pickle.dump(model, file)

# Training-set accuracy: an optimistic sanity check, not a validation score.
score = model.score(X, labels)
print(f'Score: {score:.4f}')