123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219 |
- import numpy as np
- import csv
- import re
- import argparse
- import sys
# Matches a ``#`` comment running to the end of the line, together with any
# whitespace immediately in front of it. Compiled once at import time because
# it is applied to every line of every trajectory file.
comment_pattern = re.compile(r'\s*#.*$')
class Position:
    """One timestamped sample of a trajectory, with or without pose data.

    :param timestamp: time stamp of the sample (stored unchanged)
    :param pos: translation as a 3-sequence ``x, y, z``, or ``None``
    :param rot_quat: orientation quaternion as a 4-sequence ``w, x, y, z``,
                     or ``None``
    """

    def __init__(self, timestamp, pos=None, rot_quat=None):
        self.timestamp = timestamp
        self.pos = pos            # x,y,z
        self.rot_quat = rot_quat  # w,x,y,z
        # Always starts out False; nothing in this class changes it.
        # NOTE(review): presumably a key-frame flag -- confirm with callers.
        self.is_kf = False
        # A sample counts as missing only when neither component was given.
        self.exist = not (pos is None and rot_quat is None)
class Trajectory:
    """Column-major numeric view of a sequence of poses.

    After construction the instance always carries:

    * ``num_poses``  -- number of poses ``N``
    * ``timestamps`` -- array of shape ``(N, 1)``
    * ``pos``        -- array of shape ``(3, N)``, one column per pose
    * ``q``          -- array of shape ``(4, N)``, one column per pose, in the
                        component order the poses store their quaternion

    :param lst_poses: sequence of objects exposing ``timestamp``, ``pos`` and
                      ``rot_quat``; ``None`` (the default) yields an empty
                      trajectory whose arrays have zero columns
    """

    def __init__(self, lst_poses=None):
        # Treat "no poses" like an empty list so that every attribute exists
        # even on a default-constructed instance.
        poses = [] if lst_poses is None else lst_poses

        self.num_poses = len(poses)
        self.timestamps = np.zeros((self.num_poses, 1))
        pos_rows = np.zeros((self.num_poses, 3))
        quat_rows = np.zeros((self.num_poses, 4))

        for row, pose in enumerate(poses):
            self.timestamps[row] = pose.timestamp
            # Poses without data keep their all-zero row instead of failing
            # on an attempt to index None.
            if pose.rot_quat is not None:
                quat_rows[row, :] = pose.rot_quat[:4]
            if pose.pos is not None:
                pos_rows[row, :] = pose.pos[:3]

        # Store one pose per column: (3, N) and (4, N).
        self.pos = np.transpose(pos_rows)
        self.q = np.transpose(quat_rows)
def load_file(path_file_name: str, delimiter=" "):
    """
    Read a trajectory file into a list of ``Position`` objects.

    Comments and blank lines are removed by ``skip_comments``. Every
    remaining line is split on ``delimiter``. A line with exactly eight
    fields is read as ``timestamp tx ty tz qx qy qz qw``; any other line
    contributes a ``Position`` carrying only the timestamp taken from its
    first field.

    :param path_file_name: path of the file to read
    :param delimiter: field separator passed to ``str.split``
    :return: list of ``Position`` objects in file order
    """
    poses = []
    with open(path_file_name) as handle:
        for row in skip_comments(handle):
            fields = row.split(delimiter)
            stamp = float(fields[0])

            if len(fields) != 8:
                # No usable pose on this line: keep the timestamp only.
                poses.append(Position(stamp))
                continue

            translation = [float(fields[k]) for k in (1, 2, 3)]
            # The file stores the quaternion as qx qy qz qw; Position wants
            # it as w, x, y, z.
            quaternion = [float(fields[k]) for k in (7, 4, 5, 6)]
            poses.append(Position(stamp, translation, quaternion))
    return poses
def associate(poses_test, poses_gt, max_difference=1e3):
    """
    Pair estimated poses with ground-truth poses by timestamp.

    Both sequences are swept once from front to back with one index each.
    Two poses are paired when their timestamps differ by less than
    ``max_difference``; otherwise the index pointing at the earlier
    timestamp is advanced. Because the indices only ever move forward, the
    sweep assumes both inputs are ordered by increasing timestamp.

    :param poses_test: sequence of objects with a ``timestamp`` attribute
    :param poses_gt: sequence of objects with a ``timestamp`` attribute
    :param max_difference: exclusive upper bound on the absolute timestamp
                           difference of a pair, in timestamp units
    :return: tuple ``(matched_test, matched_gt)`` of two equally long lists;
             entries at the same index form a pair
    """
    matched_test = []
    matched_gt = []
    i_test = 0
    i_gt = 0

    while i_test < len(poses_test) and i_gt < len(poses_gt):
        pose_test = poses_test[i_test]
        pose_gt = poses_gt[i_gt]

        if abs(pose_test.timestamp - pose_gt.timestamp) < max_difference:
            matched_test.append(pose_test)
            matched_gt.append(pose_gt)
            i_test += 1
            i_gt += 1
        elif pose_test.timestamp < pose_gt.timestamp:
            i_test += 1
        else:
            # Plain ``else`` rather than a ``>`` test: a NaN timestamp fails
            # every comparison, and some index must still advance or the
            # loop would never terminate.
            i_gt += 1

    return matched_test, matched_gt
def create_trajectory(lst_poses_test, lst_poses_gt):
    """
    Build the two trajectories that are compared against each other.

    Only indices whose estimated pose has ``exist`` set are kept; the
    ground-truth pose at the same index is kept alongside it, so both
    results contain the same number of poses.

    :param lst_poses_test: estimated poses, index-aligned with
                           ``lst_poses_gt``
    :param lst_poses_gt: ground-truth poses
    :return: tuple ``(traj_test, traj_gt)`` of ``Trajectory`` objects
    """
    kept_test = []
    kept_gt = []
    for idx, pose_test in enumerate(lst_poses_test):
        if not pose_test.exist:
            continue
        kept_test.append(pose_test)
        kept_gt.append(lst_poses_gt[idx])
    return Trajectory(kept_test), Trajectory(kept_gt)
def skip_comments(lines):
    """
    A filter which strips ``#`` comments and surrounding whitespace from
    each line and yields only the lines that still have content

    :param lines: any object which we can iterate through such as a file
                  object, list, tuple, or generator
    """
    for raw in lines:
        text = comment_pattern.sub('', raw).strip()
        if not text:
            # Blank line, or a line that held nothing but a comment.
            continue
        yield text
def align(traj_test, traj_gt):
    """
    Align the estimated trajectory to the ground truth and report the
    absolute trajectory error (ATE).

    A rotation is obtained from the SVD of the cross-covariance of the two
    centred point sets, and a scale factor from the ratio of the rotated
    estimate's projection onto the ground truth to the estimate's own
    spread. The RMSE of the per-pose position error is then computed twice:
    once with rotation, translation and scale applied (7DoF) and once with
    rotation and translation only (6DoF). The three numbers are printed and
    also returned.

    :param traj_test: object whose ``pos`` attribute is a ``(3, N)`` array
                      of estimated positions, one column per pose
    :param traj_gt: object whose ``pos`` attribute is a ``(3, N)`` array of
                    ground-truth positions, column-aligned with ``traj_test``
    :return: tuple ``(ate_rmse_7dof, scale, ate_rmse_6dof)`` of floats
    :raises ValueError: if the estimated positions are not a non-empty
                        ``(3, N)`` array or the two shapes differ
    """
    pos_test = np.asarray(traj_test.pos, dtype=float)
    pos_gt = np.asarray(traj_gt.pos, dtype=float)
    if pos_test.ndim != 2 or pos_test.shape[0] != 3 or pos_test.shape[1] == 0:
        raise ValueError("align() needs a non-empty 3xN estimated trajectory")
    if pos_gt.shape != pos_test.shape:
        raise ValueError("estimated and ground-truth trajectories must have the same shape")
    num_poses = pos_test.shape[1]

    # Centre both point sets on their centroids (kept as 3x1 columns).
    mean_test = pos_test.mean(axis=1, keepdims=True)
    mean_gt = pos_gt.mean(axis=1, keepdims=True)
    centered_test = pos_test - mean_test
    centered_gt = pos_gt - mean_gt

    # Cross-covariance: sum over poses of outer(test_i, gt_i).
    W = np.matmul(centered_test, centered_gt.T)
    U, _, Vh = np.linalg.svd(W.T)
    S = np.identity(3)
    if np.linalg.det(U) * np.linalg.det(Vh) < 0:
        # Flip the last axis so the result is a proper rotation rather than
        # a reflection.
        S[2, 2] = -1
    Rot = np.matmul(np.matmul(U, S), Vh)

    # Scale: sum_i <gt_i, Rot test_i> / sum_i |test_i|^2.
    rotated_test = np.matmul(Rot, centered_test)
    dots = np.sum(centered_gt * rotated_test)
    norms = np.sum(centered_test * centered_test)
    # With zero spread (for example a single pose) every scale produces the
    # same alignment, so fall back to 1.0 instead of dividing by zero.
    s = float(dots / norms) if norms > 0 else 1.0

    trans_scale = mean_gt - s * np.matmul(Rot, mean_test)
    aligned_scale = s * np.matmul(Rot, pos_test) + trans_scale
    trans = mean_gt - np.matmul(Rot, mean_test)
    aligned = np.matmul(Rot, pos_test) + trans

    error_traj_scale = aligned_scale - pos_gt
    error_traj = aligned - pos_gt

    # Per-pose Euclidean error, then its root mean square over all poses.
    ate_scale = np.sqrt(np.sum(error_traj_scale * error_traj_scale, axis=0))
    ateRMSE_scale = np.sqrt(np.dot(ate_scale, ate_scale) / num_poses)
    ate = np.sqrt(np.sum(error_traj * error_traj, axis=0))
    ateRMSE = np.sqrt(np.dot(ate, ate) / num_poses)

    # Print the error
    print("ATE RMSE(7DoF): " + str(ateRMSE_scale))
    print("scale: %f " % s)
    print("ATE RMSE(6DoF): " + str(ateRMSE))

    return float(ateRMSE_scale), s, float(ateRMSE)
if __name__ == '__main__':
    # Command line: <estimated trajectory> <ground-truth trajectory>
    cli = argparse.ArgumentParser(description='''
    This script computes the absolute trajectory error from the ground truth trajectory and the estimated trajectory.
    ''')
    cli.add_argument('SLAM', help='estimated trajectory (format: timestamp tx ty tz qx qy qz qw)')
    cli.add_argument('GT', help='ground truth trajectory (format: timestamp tx ty tz qx qy qz qw)')
    options = cli.parse_args()

    # The ground-truth file is split on commas, the estimate on single
    # spaces. The ground truth is read first.
    gt_poses = load_file(options.GT, ",")
    estimated_poses = load_file(options.SLAM, " ")

    # Pair by timestamp, drop pairs without an estimated pose, then align
    # and print the error metrics.
    matched_estimated, matched_gt = associate(estimated_poses, gt_poses)
    traj_test, traj_gt = create_trajectory(matched_estimated, matched_gt)
    align(traj_test, traj_gt)
|