import json
import os
import random
import time
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image, ImageOps
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.io import read_image

# Run configuration: dataset path, class list, hyperparameters, env flags.
with open('embedding_config.json') as config_file:
    config = json.load(config_file)

# Pick the tqdm flavour for the execution environment.
if config["env"] == "colab":
    from tqdm.notebook import tqdm
else:
    # NOTE(review): the original only handled env == "" and left tqdm
    # undefined for any other value; fall back to plain tqdm instead.
    from tqdm import tqdm

path_to_dataset = Path(config["path_to_dataset"])

# If no classes are configured, use every subdirectory of the dataset root.
if len(config["classes"]) == 0:
    classes = sorted(x.name for x in path_to_dataset.iterdir() if x.is_dir())
else:
    classes = config["classes"]


class SquarePad:
    """Zero-pad a CHW image tensor so its spatial dims become square."""

    def __call__(self, image):
        _, h, w = image.size()  # torchvision tensors are (C, H, W)
        side = max(h, w)
        # Split the padding between opposite sides; the remainder goes to
        # the right/bottom so odd differences still yield an exactly
        # square image (the original dropped one pixel in that case).
        left = (side - w) // 2
        top = (side - h) // 2
        right = side - w - left
        bottom = side - h - top
        return transforms.functional.pad(
            image, (left, top, right, bottom), 0, 'constant')


class Normalize01:
    """Rescale an image tensor in place to the range [0, 1]."""

    def __call__(self, image):
        image -= image.min()
        peak = image.max()
        # Guard against constant images: the original divided by zero
        # here and produced an all-NaN tensor.
        if peak > 0:
            image /= peak
        return image


class CustomDataset(Dataset):
    """Triplet dataset over ``<path>/<class>/*.jpg`` images.

    Each item is ``([anchor, positive, negative], anchor_label)`` where the
    positive is a random image sharing the anchor's label (possibly the
    anchor itself, matching the original sampling) and the negative is a
    random image with any other label.
    """

    def __init__(self, path, classes=None, augmentations=None,
                 target_transform=None, size=(64, 64)):
        self.path = path  # dataset root: one sub-directory per class
        self.transform_aug = augmentations
        self.target_transform = target_transform
        self.size = size
        self.classes = classes

        self.paths_to_images = []
        self.labels = []
        for c in self.classes:
            paths_to_class = list(Path(self.path, c).glob('*.jpg'))
            self.labels += [c] * len(paths_to_class)
            self.paths_to_images += paths_to_class
        self.labels = np.array(self.labels)

        # Base preprocessing is identical for every sample; build the
        # Compose once instead of on every __getitem__ call.
        self._base_transform = transforms.Compose([
            SquarePad(),
            transforms.Resize(self.size),
            Normalize01(),
        ])
        # Precompute per-class index pools so positive/negative sampling
        # is a dict lookup instead of an O(n) scan per access.
        self._positive_pool = {
            c: np.flatnonzero(self.labels == c) for c in self.classes}
        self._negative_pool = {
            c: np.flatnonzero(self.labels != c) for c in self.classes}

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        anchor_label = self.labels[idx]
        positive_idx = np.random.choice(self._positive_pool[anchor_label])
        negative_idx = np.random.choice(self._negative_pool[anchor_label])

        images = [
            self._base_transform(read_image(str(self.paths_to_images[i])).float())
            for i in (idx, positive_idx, negative_idx)
        ]
        if self.transform_aug is not None:
            images = [self.transform_aug(img) for img in images]
        return images, anchor_label


class EmbeddingModel(nn.Module):
    """CNN mapping a 3x64x64 image to an ``emb_dim``-dimensional embedding."""

    def __init__(self, emb_dim=128):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, 3),
            nn.BatchNorm2d(16),
            nn.PReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3),
            nn.BatchNorm2d(32),
            nn.PReLU(32),  # per-channel slope parameters
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3),
            nn.PReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2),
        )
        # 64x64 input -> 6x6 spatial map after three conv+pool stages.
        self.fc = nn.Sequential(
            nn.Linear(64 * 6 * 6, 256),
            nn.PReLU(),
            nn.Linear(256, emb_dim),
        )

    def forward(self, x):
        x = self.conv(x)
        # flatten(x, 1) keeps the batch dimension explicit — safer than
        # view(-1, ...) which can silently mis-shape the batch.
        x = torch.flatten(x, 1)
        return self.fc(x)


# Optional training-time augmentations.
augmentations = None
if config["augmentations"]:
    augmentations = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomAdjustSharpness(sharpness_factor=2),
        transforms.RandomAutocontrast(),
        transforms.ColorJitter(brightness=0.3),
    ])

batch_size = config["batch_size"]
dataset = CustomDataset(path_to_dataset, augmentations=augmentations,
                        classes=classes, size=(64, 64))

# Random train/test split by the configured fraction.
train_size = int(config["train_size"] * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

if torch.cuda.is_available():
    print('Using GPU.')
    device = 'cuda'
else:
    print("CUDA not detected, using CPU.")
    device = 'cpu'

embedding_dims = config["embedding_dims"]
epochs = config["epochs"]

model = EmbeddingModel(embedding_dims).to(device)
optimizer = optim.Adam(model.parameters(), lr=config["lr"])
triplet_loss = nn.TripletMarginLoss(margin=config["triplet-loss-margin"],
                                    p=config["triplet-loss-p"])

train_loss = []  # per-batch training losses across all epochs
test_loss = []   # per-batch test losses across all epochs

print('Training.')
for epoch in range(epochs):
    train_one_epoch_loss = []
    test_one_epoch_loss = []

    # ---- training pass ----
    model.train()
    for (anchor_image, positive_image, negative_image), _ in train_loader:
        anchor_image = anchor_image.to(device)
        positive_image = positive_image.to(device)
        negative_image = negative_image.to(device)

        optimizer.zero_grad()
        loss = triplet_loss(model(anchor_image),
                            model(positive_image),
                            model(negative_image))
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_one_epoch_loss.append(batch_loss)
        train_loss.append(batch_loss)

    # ---- evaluation pass ----
    # The original evaluated in train mode with autograd enabled, which
    # updates the BatchNorm running statistics from *test* data and wastes
    # memory building graphs; eval() + no_grad() fixes both.
    model.eval()
    with torch.no_grad():
        for (anchor_image, positive_image, negative_image), _ in test_loader:
            anchor_image = anchor_image.to(device)
            positive_image = positive_image.to(device)
            negative_image = negative_image.to(device)

            batch_loss = triplet_loss(model(anchor_image),
                                      model(positive_image),
                                      model(negative_image)).item()
            test_one_epoch_loss.append(batch_loss)
            test_loss.append(batch_loss)

    print(f"Epoch: {epoch+1}/{epochs} - Training loss: "
          f"{np.mean(train_one_epoch_loss):.4f} - Test loss: "
          f"{np.mean(test_one_epoch_loss):.4f}")

# ---- save artifacts ----
output_dir = Path('./embedding-output')
os.makedirs(output_dir, exist_ok=True)
torch.save(model.state_dict(), output_dir / 'model_embedding.pt')

train_loss = np.array(train_loss)
test_loss = np.array(test_loss)
loss = {
    "test_loss": list(test_loss.astype(float)),
    "train_loss": list(train_loss.astype(float)),
}
with open(str(output_dir / 'loss.json'), 'w') as f:
    json.dump(loss, f, indent=4)

# Smooth both curves before plotting; the train series has more points
# than the test series, so its window is widened by the difference to
# keep the two plots on comparable x-scales. q tunes the smoothing.
q = 10
train_window = len(train_loss) - len(test_loss) + q
plt.plot(np.convolve(train_loss, np.ones(train_window) / train_window,
                     mode='valid'), label='Train loss')
plt.plot(np.convolve(test_loss, np.ones(q) / q, mode='valid'),
         label='Test loss')
plt.legend()
plt.title("Epoch loss")
plt.savefig(output_dir / "embedding_loss.pdf")

if config["visualize"]:
    X_train = []
    labels_train = []
    images_train = []

    # Embed the anchor image of every training batch. No gradients are
    # needed for inference, and eval mode keeps BatchNorm deterministic
    # (the original ran this in train mode with autograd on).
    model.eval()
    with torch.no_grad():
        for (batch, _, _), label in train_loader:
            X_train += [*model(batch.to(device)).cpu().numpy()]
            labels_train += [*label]
            for x in batch:
                # Rescale to [0, 1] for PIL conversion; guard the
                # constant-image case against division by zero.
                x -= x.min()
                peak = x.max()
                if peak > 0:
                    x /= peak
                images_train += [transforms.functional.to_pil_image(x)]

    X_train = np.array(X_train)
    labels_train = np.array(labels_train)

    # PCA first to compress/denoise the embeddings, then t-SNE to 2-D.
    pca = PCA(n_components=config["pca_components"])
    X_train_pca = pca.fit_transform(X_train)
    tsne = TSNE(n_components=2, learning_rate='auto', init='pca')
    X_train_tsne = tsne.fit_transform(X_train_pca)

    cmap = plt.get_cmap('tab20')
    colors = cmap(np.linspace(0, 1, len(classes)))
    classes_to_colors = dict(zip(classes, colors))

    plt.figure(figsize=(15, 10))
    for cls in classes:
        plt.scatter(*X_train_tsne[labels_train == cls].T,
                    color=classes_to_colors[cls], label=cls, s=1)
    legend = plt.legend(fontsize=10)
    # Enlarge the legend markers so they are readable at s=1 scatter size.
    # NOTE(review): legendHandles is deprecated in matplotlib >= 3.7 in
    # favour of legend_handles — kept for compatibility with older mpl.
    for handle in legend.legendHandles:
        handle._sizes = [30]
    plt.axis('off')
    plt.savefig(output_dir / 'visualized_simple.pdf')