This commit is contained in:
vd
2022-09-09 21:45:11 +03:00
parent 6a8c963dd9
commit 68812a209b
17504 changed files with 235 additions and 7 deletions

View File

@@ -7,6 +7,7 @@ import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import matplotlib as mpl
import os
from pathlib import Path
import json
from torch.utils.data import Dataset, DataLoader
@@ -176,6 +177,8 @@ train_loss = []
test_loss = []
epoch_all_loss = []
print('Training.')
for epoch in range(epochs):
train_one_epoch_loss = []
@@ -219,14 +222,27 @@ for epoch in range(epochs):
print(f"Epoch: {epoch+1}/{epochs} - Training loss: {np.mean(train_one_epoch_loss):.4f} - Test loss: {np.mean(test_one_epoch_loss):.4f}")
output_dir = Path('./embedding-output')
os.makedirs(output_dir, exist_ok=True)
torch.save(model.state_dict(), output_dir/'model_embedding.pt')
train_loss = np.array(train_loss)
test_loss = np.array(test_loss)
q = 10
plt.plot(np.convolve(train_loss, np.ones(len(train_loss) - len(test_loss) + q)/(len(train_loss) - len(test_loss) + q), mode = 'valid'), legend = 'Train loss')
plt.plot(np.convolve(test_loss, np.ones(q)/q, mode = 'valid'), legend = 'Test loss')
loss = {}
loss["test_loss"] = list(test_loss.astype(float))
loss["train_loss"] = list(train_loss.astype(float))
with open(str(output_dir / 'loss.json'), 'w') as f:
json.dump(loss, f, indent = 4)
q = 10 # plotting parameter
plt.plot(np.convolve(train_loss, np.ones(len(train_loss) - len(test_loss) + q)/(len(train_loss) - len(test_loss) + q), mode = 'valid'), label = 'Train loss')
plt.plot(np.convolve(test_loss, np.ones(q)/q, mode = 'valid'), label = 'Test loss')
plt.legend()
plt.title("Epoch loss")
plt.savefig("embedding_loss.pdf")
plt.savefig(output_dir/"embedding_loss.pdf")
if config["visualize"]:
@@ -261,5 +277,6 @@ if config["visualize"]:
for i in range(len(classes)):
legend.legendHandles[i]._sizes = [30]
plt.savefig('visualized_simple.pdf')
plt.axis('off')
plt.savefig(output_dir/'visualized_simple.pdf')