Update
This commit is contained in:
@@ -7,6 +7,7 @@ import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib as mpl
|
||||
import os
|
||||
from pathlib import Path
|
||||
import json
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
@@ -176,6 +177,8 @@ train_loss = []
|
||||
test_loss = []
|
||||
epoch_all_loss = []
|
||||
|
||||
print('Training.')
|
||||
|
||||
for epoch in range(epochs):
|
||||
|
||||
train_one_epoch_loss = []
|
||||
@@ -219,14 +222,27 @@ for epoch in range(epochs):
|
||||
|
||||
print(f"Epoch: {epoch+1}/{epochs} - Training loss: {np.mean(train_one_epoch_loss):.4f} - Test loss: {np.mean(test_one_epoch_loss):.4f}")
|
||||
|
||||
|
||||
output_dir = Path('./embedding-output')
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
torch.save(model.state_dict(), output_dir/'model_embedding.pt')
|
||||
|
||||
train_loss = np.array(train_loss)
|
||||
test_loss = np.array(test_loss)
|
||||
q = 10
|
||||
plt.plot(np.convolve(train_loss, np.ones(len(train_loss) - len(test_loss) + q)/(len(train_loss) - len(test_loss) + q), mode = 'valid'), legend = 'Train loss')
|
||||
plt.plot(np.convolve(test_loss, np.ones(q)/q, mode = 'valid'), legend = 'Test loss')
|
||||
|
||||
loss = {}
|
||||
loss["test_loss"] = list(test_loss.astype(float))
|
||||
loss["train_loss"] = list(train_loss.astype(float))
|
||||
with open(str(output_dir / 'loss.json'), 'w') as f:
|
||||
json.dump(loss, f, indent = 4)
|
||||
|
||||
q = 10 # plotting parameter
|
||||
plt.plot(np.convolve(train_loss, np.ones(len(train_loss) - len(test_loss) + q)/(len(train_loss) - len(test_loss) + q), mode = 'valid'), label = 'Train loss')
|
||||
plt.plot(np.convolve(test_loss, np.ones(q)/q, mode = 'valid'), label = 'Test loss')
|
||||
plt.legend()
|
||||
plt.title("Epoch loss")
|
||||
plt.savefig("embedding_loss.pdf")
|
||||
plt.savefig(output_dir/"embedding_loss.pdf")
|
||||
|
||||
if config["visualize"]:
|
||||
|
||||
@@ -261,5 +277,6 @@ if config["visualize"]:
|
||||
for i in range(len(classes)):
|
||||
legend.legendHandles[i]._sizes = [30]
|
||||
|
||||
plt.savefig('visualized_simple.pdf')
|
||||
plt.axis('off')
|
||||
plt.savefig(output_dir/'visualized_simple.pdf')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user