Files
place6D_Nurburgring/Karussell/Supervisely/SuperviselyKeypointsGUI/SuperviselyKeypointsGUI.py
2022-05-18 12:01:01 +03:00

250 lines
9.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import supervisely_lib as sly
import pandas as pd
import cv2 as cv
import os
# from PIL import Image
import numpy as np
# import open3d as o3d
import wx
import wx.xrc
def last_keypoints_on_img(ann_info):
    """Return the most recently updated object(s) of an annotation.

    Parameters
    ----------
    ann_info
        Indexable annotation record whose element ``[2]`` is a mapping with
        an ``'objects'`` list. Every object must provide ``'classTitle'``
        and ``'updatedAt'`` entries.

    Returns
    -------
    pandas.DataFrame
        The rows (columns ``classTitle`` and ``updatedAt``) whose timestamp
        equals the newest ``updatedAt``. The row index is the position of
        the object inside the ``'objects'`` list. The frame is empty when
        the annotation has no objects.
    """
    rows = [[obj['classTitle'], obj['updatedAt']]
            for obj in ann_info[2]['objects']]
    frame = pd.DataFrame(rows, columns=['classTitle', 'updatedAt'])
    # Parse the timestamps so that the comparison below is chronological.
    frame['updatedAt'] = pd.to_datetime(frame['updatedAt'])
    newest = frame['updatedAt'].max()
    return frame[frame['updatedAt'] == newest]
def label2hash(meta_json, last):
    """Map keypoint labels to node hashes for the class named in *last*.

    Parameters
    ----------
    meta_json : dict
        Project meta with a ``'classes'`` list. Each class has a ``'title'``
        and, for the matching class, ``'geometry_config'['nodes']`` mapping
        a node hash to a dict that contains a ``'label'``.
    last : pandas.DataFrame
        Frame with a ``classTitle`` column; only its first row is used.

    Returns
    -------
    dict
        ``{label: node_hash}`` for every node of the matching class.

    Raises
    ------
    ValueError
        If no class in *meta_json* has the requested title.
    """
    class_title = last['classTitle'].values[0]

    meta_nodes = None
    for clss in meta_json['classes']:
        if clss['title'] == class_title:
            # No early exit: if several classes share the title, the last
            # one in the list wins.
            meta_nodes = clss['geometry_config']['nodes']

    if meta_nodes is None:
        raise ValueError(
            'Class {!r} not found in project meta'.format(class_title))

    return {node['label']: name for name, node in meta_nodes.items()}
def fit(imageSize, keypoints_2d, keypoints_3d, focus=1):
    """Estimate the camera pose from 2D/3D keypoint correspondences.

    Parameters
    ----------
    imageSize
        Two numbers; element ``[1]`` is halved for the principal point x
        and element ``[0]`` for y (presumably ``(height, width)`` - confirm
        against the caller).
    keypoints_2d : pandas.DataFrame
        Image points with ``x`` and ``y`` columns; its index selects the
        matching rows of *keypoints_3d*.
    keypoints_3d : pandas.DataFrame
        Model points, indexed with the same labels as *keypoints_2d*.
    focus : float, optional
        Multiplier applied to the image diagonal to obtain the focal length.

    Returns
    -------
    tuple
        ``(rvec, tvec, cameraMatrix, distCoeffs)``, in the argument order
        expected by ``cv.projectPoints`` after the object points.

    Raises
    ------
    ValueError
        If fewer than six 2D keypoints are supplied.
    """
    # Model points are picked in the same order as the annotated points.
    model_pts = keypoints_3d.loc[keypoints_2d.index].values
    pixel_pts = keypoints_2d[['x', 'y']].values.astype('float')

    # Focal length is expressed relative to the image diagonal.
    focal = focus * np.hypot(*imageSize)
    principal_x = imageSize[1] / 2
    principal_y = imageSize[0] / 2
    no_distortion = np.zeros(4, np.float32)

    if len(pixel_pts) <= 5:
        raise ValueError('Number of keypoints must be > 5')

    intrinsics = np.array(
        [[focal, 0, principal_x],
         [0, focal, principal_y],
         [0, 0, 1]],
        dtype=np.float32,
    )
    _, rotation, translation = cv.solvePnP(
        model_pts, pixel_pts, intrinsics, no_distortion,
        flags=cv.SOLVEPNP_ITERATIVE,
    )
    return rotation, translation, intrinsics, no_distortion
def draw_cloud(img, points_3d, params):
    """Project *points_3d* into the image and mark each one as a dot.

    Parameters
    ----------
    img
        Image passed to ``cv.circle``.
    points_3d
        3D points handed to ``cv.projectPoints``.
    params
        Remaining positional arguments of ``cv.projectPoints`` (the tuple
        returned by ``fit``).

    Returns
    -------
    None
        Nothing is returned; the dots are drawn through ``cv.circle``.
    """
    dot_colour = (255, 20, 147)
    projected = cv.projectPoints(points_3d, *params)[0]
    for pixel in projected[:, 0]:
        # Radius 0 with filled thickness gives the smallest possible mark.
        img = cv.circle(img, pixel.astype(int), 0, dot_colour, -1)
def draw_keypoints(img, keypoints_3d, params):
    """Project *keypoints_3d* into the image and draw each as a filled disc.

    Parameters
    ----------
    img
        Image passed to ``cv.circle``.
    keypoints_3d
        3D keypoints handed to ``cv.projectPoints``.
    params
        Remaining positional arguments of ``cv.projectPoints`` (the tuple
        returned by ``fit``).

    Returns
    -------
    None
        Nothing is returned; the discs are drawn through ``cv.circle``.
    """
    disc_radius = 5
    disc_colour = (0, 0, 0)
    projected = cv.projectPoints(keypoints_3d, *params)[0]
    for pixel in projected[:, 0]:
        img = cv.circle(img, pixel.astype(int), disc_radius, disc_colour, -1)
def resize(img, width=1000):
    """Scale *img* to the given *width*, keeping its aspect ratio.

    The new height is ``int(rows / cols * width)``, i.e. truncated.
    """
    rows, cols = img.shape[:2]
    scaled_height = int(rows / cols * width)
    return cv.resize(img, (width, scaled_height))
class Start_annotation():
    """Re-project keypoint annotations of Supervisely images from a 3D model.

    A camera pose is fitted to a subset of the annotated 2D keypoints; the
    3D model keypoints are then projected back into the image and can be
    uploaded as the new annotation and/or shown in an OpenCV window.
    """

    def __init__(self, project_id, dataset_id, token,
                 local_dataset_path, keypoints_3d_path, point_cloud_path):
        """Connect to Supervisely and load the model and the image list.

        Parameters
        ----------
        project_id, dataset_id
            Supervisely identifiers of the project and dataset to work on.
        token
            API token used to authenticate against the server.
        local_dataset_path
            Directory holding local copies of the images (used for plots).
        keypoints_3d_path
            CSV file with the 3D keypoints; its first column is the index.
        point_cloud_path
            Path of a dense point cloud. It is only stored, because no
            point-cloud reader is imported by this module.
        """
        server_address = 'https://app.supervise.ly/'
        # The plot branch of transform_by_visible reads both attributes, so
        # they must always exist.
        self.local_dataset_path = local_dataset_path
        self.point_cloud_path = point_cloud_path
        # Assign an array of 3D points here to overlay a dense cloud in
        # plots; while it is None the projected model keypoints are drawn.
        self.points_3d = None
        self.keypoints_3d = pd.read_csv(keypoints_3d_path, index_col=0).astype(float)
        self.api = sly.Api(server_address, token)
        self.meta_json = self.api.project.get_meta(project_id)
        self.meta = sly.ProjectMeta.from_json(self.meta_json)
        self.images = pd.DataFrame(
            self.api.image.get_list(dataset_id)
        ).sort_values('name', ignore_index=True)

    def load_ann(self, img_id):
        """Download and return the annotation record of image *img_id*."""
        return self.api.annotation.download(img_id)

    def annotations(self, ann_info):
        """Return the 2D keypoints of the most recently updated object.

        Returns
        -------
        pandas.DataFrame or None
            Frame with ``x`` and ``y`` columns indexed ``1..n`` by keypoint
            label, or ``None`` when the image has no annotated objects.
        """
        last = last_keypoints_on_img(ann_info)
        if len(last) == 0:
            return None
        nodes = ann_info[2]['objects'][last.index[0]]['nodes']
        # The label -> hash mapping is the same for every keypoint, so it
        # is built once rather than on every iteration.
        hash_by_label = label2hash(self.meta_json, last)
        keypoints_2d = pd.DataFrame(columns=['x', 'y'])
        for i in range(1, len(nodes) + 1):
            keypoints_2d.loc[i] = nodes[hash_by_label[str(i)]]['loc']
        return keypoints_2d

    def new_annotations(self, ann_info, new_keypoints):
        """Write *new_keypoints* into *ann_info* in place and return it.

        Only the keypoints whose label appears in ``new_keypoints.index``
        are overwritten; the most recently updated object is the target.
        """
        last = last_keypoints_on_img(ann_info)
        nodes = ann_info[2]['objects'][last.index[0]]['nodes']
        hash_by_label = label2hash(self.meta_json, last)
        for i in new_keypoints.index:
            nodes[hash_by_label[str(i)]]['loc'] = new_keypoints.loc[i].tolist()
        return ann_info

    def start(self):
        """Open the GUI window and run the wx main loop until it exits."""
        app = wx.App()
        wnd = GUI(self.images, self.transform_by_visible)
        wnd.Show(True)
        app.MainLoop()

    def transform_by_visible(self, idxs, img_id, name, focus=1, send=True, all_points=False, change_all=False, plot=False):
        """Fit the camera on selected keypoints and re-project the model.

        Parameters
        ----------
        idxs
            Labels of the keypoints used for the fit.
        img_id
            Supervisely image identifier.
        name
            Image file name inside ``local_dataset_path`` (used for plots
            and as the window title).
        focus
            Focal length multiplier forwarded to ``fit``.
        send
            Upload the re-projected annotation.
        all_points
            Fit on every annotated keypoint instead of only *idxs*.
        change_all
            Also overwrite the keypoints listed in *idxs*; by default they
            keep their annotated position.
        plot
            Show the projection on the local image in an OpenCV window.

        Returns
        -------
        int or str
            Number of keypoints used for the fit, ``'Error_empty_request'``
            when neither *send* nor *plot* is requested, or
            ``'Error_annotations'`` when the image has no annotation.

        Raises
        ------
        FileNotFoundError
            If *plot* is requested and the local image cannot be read.
        """
        if not send and not plot:
            return 'Error_empty_request'
        ann_info = self.load_ann(img_id)
        keypoints_2d = self.annotations(ann_info)
        if keypoints_2d is None:
            return 'Error_annotations'
        if not all_points:
            keypoints_2d = keypoints_2d.loc[idxs]
        # NOTE(review): relies on the 'size' mapping yielding the two values
        # in the order fit() expects - confirm against the annotation format.
        imgSize = list(ann_info.annotation['size'].values())
        params = fit(imgSize, keypoints_2d, self.keypoints_3d, focus)
        new_keypoints = pd.DataFrame(
            cv.projectPoints(self.keypoints_3d.values, *params)[0][:, 0],
            columns=['x', 'y'], index=range(1, len(self.keypoints_3d) + 1))
        if not change_all:
            new_keypoints = new_keypoints.drop(idxs)
        if send:
            new_ann = self.new_annotations(ann_info, new_keypoints)
            new_ann = sly.Annotation.from_json(new_ann.annotation, self.meta)
            self.api.annotation.upload_ann(img_id, new_ann)
        if plot:
            img_path = os.path.join(self.local_dataset_path, name)
            img = cv.imread(img_path)
            if img is None:
                # cv.imread signals an unreadable file with None, not an error.
                raise FileNotFoundError('Cannot read image: {}'.format(img_path))
            if self.points_3d is not None:
                draw_cloud(img, self.points_3d, params)
            else:
                draw_keypoints(img, self.keypoints_3d.values, params)
            cv.imshow(name, resize(img, 1200))
            key = cv.waitKey(0)
            # 233 is presumably the same physical key as 'q' under another
            # keyboard layout - TODO confirm.
            if key == ord('q') or key == 233:
                cv.destroyAllWindows()
        return len(keypoints_2d)
class GUI(wx.Frame):
    """Small control window for selecting keypoints and running the fit.

    The frame shows one check box per keypoint, "SendKeypoints" and "Plot"
    radio boxes, an "All points" check box, a counter label, a "Go" button
    and a drop-down with the image names.
    """

    def __init__(self, images, func, n_keypoints=25):
        """Build the widgets.

        Parameters
        ----------
        images : pandas.DataFrame
            Image table with at least ``name`` and ``id`` columns.
        func : callable
            Called by ``Go`` as ``func(idxs, img_id, img_name, 1, send=...,
            plot=..., all_points=...)``.
        n_keypoints : int, optional
            Number of keypoint check boxes to create.
        """
        wx.Frame.__init__(
            self, None, id=wx.ID_ANY, title='SuperviselyKeypointsGui',
            pos=wx.DefaultPosition, size=wx.Size(250, 450),
            style=wx.CAPTION | wx.CLOSE_BOX | wx.SYSTEM_MENU | wx.RESIZE_BORDER | wx.TAB_TRAVERSAL)
        self.images = images
        self.func = func

        self.SetSizeHints(wx.DefaultSize, wx.DefaultSize)

        gbSizer1 = wx.GridBagSizer(0, 0)
        gbSizer1.SetFlexibleDirection(wx.VERTICAL)
        gbSizer1.SetNonFlexibleGrowMode(wx.FLEX_GROWMODE_SPECIFIED)
        gbSizer1.SetMinSize(wx.Size(200, 400))

        # Left column: one check box per keypoint, labelled from 1.
        bSizer3 = wx.BoxSizer(wx.VERTICAL)
        self.m_checkBoxes = []
        for i in range(n_keypoints):
            check_box = wx.CheckBox(self, wx.ID_ANY, u"{}".format(str(i + 1)),
                                    wx.DefaultPosition, wx.DefaultSize, 0)
            self.m_checkBoxes.append(check_box)
            bSizer3.Add(check_box, 0, wx.ALL, 0)
        gbSizer1.Add(bSizer3, wx.GBPosition(0, 0), wx.GBSpan(1, 1),
                     wx.EXPAND | wx.LEFT | wx.TOP, 5)

        # Right column: options, counter and the action button.
        wSizer5 = wx.WrapSizer(wx.HORIZONTAL, 0)

        m_radioBox5Choices = [u"True", u"False"]
        self.radioBox_send = wx.RadioBox(self, wx.ID_ANY, u"SendKeypoints",
                                         wx.DefaultPosition, wx.DefaultSize,
                                         m_radioBox5Choices, 1, wx.RA_SPECIFY_COLS)
        self.radioBox_send.SetSelection(1)  # second choice: "False"
        wSizer5.Add(self.radioBox_send, 1, wx.ALL, 5)

        m_radioBox7Choices = [u"True", u"False"]
        self.radioBox_plot = wx.RadioBox(self, wx.ID_ANY, u"Plot",
                                         wx.DefaultPosition, wx.DefaultSize,
                                         m_radioBox7Choices, 1, wx.RA_SPECIFY_COLS)
        self.radioBox_plot.SetSelection(1)  # second choice: "False"
        wSizer5.Add(self.radioBox_plot, 1, wx.ALL, 5)

        self.m_checkBox26 = wx.CheckBox(self, wx.ID_ANY, u"All points",
                                        wx.DefaultPosition, wx.DefaultSize, 0)
        wSizer5.Add(self.m_checkBox26, 0, wx.ALL, 5)

        self.m_staticText7 = wx.StaticText(self, wx.ID_ANY, u"0",
                                           wx.DefaultPosition, wx.DefaultSize, 0)
        self.m_staticText7.Wrap(-1)
        wSizer5.Add(self.m_staticText7, 0, wx.ALL, 5)

        self.m_button3 = wx.Button(self, wx.ID_ANY, u"Go",
                                   wx.DefaultPosition, wx.DefaultSize, 0)
        wSizer5.Add(self.m_button3, 0, wx.ALL, 5)

        gbSizer1.Add(wSizer5, wx.GBPosition(0, 1), wx.GBSpan(1, 1),
                     wx.ALIGN_CENTER_HORIZONTAL | wx.EXPAND | wx.TOP, 0)

        # Bottom row: image selector.
        m_choice3Choices = self.images.name.tolist()
        self.m_choice3 = wx.Choice(self, wx.ID_ANY, wx.DefaultPosition,
                                   wx.DefaultSize, m_choice3Choices, 0)
        self.m_choice3.SetSelection(0)
        gbSizer1.Add(self.m_choice3, wx.GBPosition(1, 0), wx.GBSpan(2, 8), wx.ALL, 5)

        self.SetSizer(gbSizer1)
        self.Layout()
        self.Centre(wx.BOTH)

        self.m_button3.Bind(wx.EVT_BUTTON, self.Go)

    def Go(self, event):
        """Handle the "Go" button: validate the selection and call ``func``.

        At least six keypoints must be checked unless "All points" is set.
        The two error strings that ``func`` may return are reported to the
        user in message boxes.
        """
        idxs = [i for i, check_box in enumerate(self.m_checkBoxes, start=1)
                if check_box.IsChecked()]
        all_points = self.m_checkBox26.IsChecked()

        if len(idxs) < 6 and not all_points:
            wx.MessageBox('Точек должно быть больше 5', 'Ошибка', wx.OK)
            event.Skip()
            return

        # The counter shows how many keypoints were requested for the fit.
        n = len(self.m_checkBoxes) if all_points else len(idxs)
        self.m_staticText7.SetLabel(str(n))

        img_name = self.m_choice3.GetString(self.m_choice3.GetSelection())
        send = self.radioBox_send.GetString(self.radioBox_send.GetSelection())
        plot = self.radioBox_plot.GetString(self.radioBox_plot.GetSelection())
        img_id = int(self.images[self.images.name == img_name].id.values[0])

        out = self.func(idxs, img_id, img_name, 1,
                        send=(send == 'True'), plot=(plot == 'True'),
                        all_points=all_points)
        if out == 'Error_empty_request':
            wx.MessageBox('Выберите plot или send', 'Ошибка', wx.OK)
        elif out == 'Error_annotations':
            wx.MessageBox('На этом изображении нет разметки', 'Ошибка', wx.OK)
        event.Skip()
        return
if __name__ == '__main__':
    # Supervisely identifiers of the data to work on.
    project_id = 184347  # (Nurburg-karussel)
    dataset_id = 627375  # (images)
    # API token; left empty here and must be filled in before running.
    token = ''

    # Local inputs: 3D keypoint table, image directory and dense cloud.
    keypoints_3d_path = r'karussel_24kps.csv'
    local_dataset_path = r''
    point_cloud_path = r''

    sp = Start_annotation(
        project_id=project_id,
        dataset_id=dataset_id,
        token=token,
        local_dataset_path=local_dataset_path,
        keypoints_3d_path=keypoints_3d_path,
        point_cloud_path=point_cloud_path,
    )
    sp.start()