# so3.py
import functools
import operator
import sys
import unittest

import sympy

import sophus
  6. class So3:
  7. """ 3 dimensional group of orthogonal matrices with determinant 1 """
  8. def __init__(self, q):
  9. """ internally represented by a unit quaternion q """
  10. self.q = q
  11. @staticmethod
  12. def exp(v):
  13. """ exponential map """
  14. theta_sq = sophus.squared_norm(v)
  15. theta = sympy.sqrt(theta_sq)
  16. return So3(
  17. sophus.Quaternion(
  18. sympy.cos(0.5 * theta),
  19. sympy.sin(0.5 * theta) / theta * v))
  20. def log(self):
  21. """ logarithmic map"""
  22. n = sympy.sqrt(sophus.squared_norm(self.q.vec))
  23. return 2 * sympy.atan(n / self.q.real) / n * self.q.vec
  24. def __repr__(self):
  25. return "So3:" + repr(self.q)
  26. def inverse(self):
  27. return So3(self.q.conj())
  28. @staticmethod
  29. def hat(o):
  30. return sympy.Matrix([[0, -o[2], o[1]],
  31. [o[2], 0, -o[0]],
  32. [-o[1], o[0], 0]])
  33. """vee-operator
  34. It takes the 3x3-matrix representation ``Omega`` and maps it to the
  35. corresponding vector representation of Lie algebra.
  36. This is the inverse of the hat-operator, see above.
  37. Precondition: ``Omega`` must have the following structure:
  38. | 0 -c b |
  39. | c 0 -a |
  40. | -b a 0 |
  41. """
  42. @staticmethod
  43. def vee(Omega):
  44. v = sophus.Vector3(Omega.row(2).col(1), Omega.row(0).col(2), Omega.row(1).col(0))
  45. return v
  46. def matrix(self):
  47. """ returns matrix representation """
  48. return sympy.Matrix([[
  49. 1 - 2 * self.q.vec[1]**2 - 2 * self.q.vec[2]**2,
  50. 2 * self.q.vec[0] * self.q.vec[1] -
  51. 2 * self.q.vec[2] * self.q[3],
  52. 2 * self.q.vec[0] * self.q.vec[2] +
  53. 2 * self.q.vec[1] * self.q[3]
  54. ], [
  55. 2 * self.q.vec[0] * self.q.vec[1] +
  56. 2 * self.q.vec[2] * self.q[3],
  57. 1 - 2 * self.q.vec[0]**2 - 2 * self.q.vec[2]**2,
  58. 2 * self.q.vec[1] * self.q.vec[2] -
  59. 2 * self.q.vec[0] * self.q[3]
  60. ], [
  61. 2 * self.q.vec[0] * self.q.vec[2] -
  62. 2 * self.q.vec[1] * self.q[3],
  63. 2 * self.q.vec[1] * self.q.vec[2] +
  64. 2 * self.q.vec[0] * self.q[3],
  65. 1 - 2 * self.q.vec[0]**2 - 2 * self.q.vec[1]**2
  66. ]])
  67. def __mul__(self, right):
  68. """ left-multiplication
  69. either rotation concatenation or point-transform """
  70. if isinstance(right, sympy.Matrix):
  71. assert right.shape == (3, 1), right.shape
  72. return (self.q * sophus.Quaternion(0, right) * self.q.conj()).vec
  73. elif isinstance(right, So3):
  74. return So3(self.q * right.q)
  75. assert False, "unsupported type: {0}".format(type(right))
  76. def __getitem__(self, key):
  77. return self.q[key]
  78. @staticmethod
  79. def calc_Dx_exp_x(x):
  80. return sympy.Matrix(4, 3, lambda r, c:
  81. sympy.diff(So3.exp(x)[r], x[c]))
  82. @staticmethod
  83. def Dx_exp_x_at_0():
  84. return sympy.Matrix([[0.5, 0.0, 0.0],
  85. [0.0, 0.5, 0.0],
  86. [0.0, 0.0, 0.5],
  87. [0.0, 0.0, 0.0]])
  88. @staticmethod
  89. def calc_Dx_exp_x_at_0(x):
  90. return So3.calc_Dx_exp_x(x).subs(x[0], 0).subs(x[1], 0).limit(x[2], 0)
  91. def calc_Dx_this_mul_exp_x_at_0(self, x):
  92. return sympy.Matrix(4, 3, lambda r, c:
  93. sympy.diff((self * So3.exp(x))[r], x[c]))\
  94. .subs(x[0], 0).subs(x[1], 0).limit(x[2], 0)
  95. def calc_Dx_exp_x_mul_this_at_0(self, x):
  96. return sympy.Matrix(3, 4, lambda r, c:
  97. sympy.diff((self * So3.exp(x))[c], x[r, 0]))\
  98. .subs(x[0], 0).subs(x[1], 0).limit(x[2], 0)
  99. @staticmethod
  100. def Dxi_x_matrix(x, i):
  101. if i == 0:
  102. return sympy.Matrix([[0, 2 * x[1], 2 * x[2]],
  103. [2 * x[1], -4 * x[0], -2 * x[3]],
  104. [2 * x[2], 2 * x[3], -4 * x[0]]])
  105. if i == 1:
  106. return sympy.Matrix([[-4 * x[1], 2 * x[0], 2 * x[3]],
  107. [2 * x[0], 0, 2 * x[2]],
  108. [-2 * x[3], 2 * x[2], -4 * x[1]]])
  109. if i == 2:
  110. return sympy.Matrix([[-4 * x[2], -2 * x[3], 2 * x[0]],
  111. [2 * x[3], -4 * x[2], 2 * x[1]],
  112. [2 * x[0], 2 * x[1], 0]])
  113. if i == 3:
  114. return sympy.Matrix([[0, -2 * x[2], 2 * x[1]],
  115. [2 * x[2], 0, -2 * x[0]],
  116. [-2 * x[1], 2 * x[0], 0]])
  117. @staticmethod
  118. def calc_Dxi_x_matrix(x, i):
  119. return sympy.Matrix(3, 3, lambda r, c:
  120. sympy.diff(x.matrix()[r, c], x[i]))
  121. @staticmethod
  122. def Dxi_exp_x_matrix(x, i):
  123. R = So3.exp(x)
  124. Dx_exp_x = So3.calc_Dx_exp_x(x)
  125. l = [Dx_exp_x[j, i] * So3.Dxi_x_matrix(R, j) for j in [0, 1, 2, 3]]
  126. return functools.reduce((lambda a, b: a + b), l)
  127. @staticmethod
  128. def calc_Dxi_exp_x_matrix(x, i):
  129. return sympy.Matrix(3, 3, lambda r, c:
  130. sympy.diff(So3.exp(x).matrix()[r, c], x[i]))
  131. @staticmethod
  132. def Dxi_exp_x_matrix_at_0(i):
  133. v = sophus.ZeroVector3()
  134. v[i] = 1
  135. return So3.hat(v)
  136. @staticmethod
  137. def calc_Dxi_exp_x_matrix_at_0(x, i):
  138. return sympy.Matrix(3, 3, lambda r, c:
  139. sympy.diff(So3.exp(x).matrix()[r, c], x[i])
  140. ).subs(x[0], 0).subs(x[1], 0).limit(x[2], 0)
  141. class TestSo3(unittest.TestCase):
  142. def setUp(self):
  143. omega0, omega1, omega2 = sympy.symbols(
  144. 'omega[0], omega[1], omega[2]', real=True)
  145. x, v0, v1, v2 = sympy.symbols('q.w() q.x() q.y() q.z()', real=True)
  146. p0, p1, p2 = sympy.symbols('p0 p1 p2', real=True)
  147. v = sophus.Vector3(v0, v1, v2)
  148. self.omega = sophus.Vector3(omega0, omega1, omega2)
  149. self.a = So3(sophus.Quaternion(x, v))
  150. self.p = sophus.Vector3(p0, p1, p2)
  151. def test_exp_log(self):
  152. for o in [sophus.Vector3(0., 1, 0.5),
  153. sophus.Vector3(0.1, 0.1, 0.1),
  154. sophus.Vector3(0.01, 0.2, 0.03)]:
  155. w = So3.exp(o).log()
  156. for i in range(0, 3):
  157. self.assertAlmostEqual(o[i], w[i])
  158. def test_matrix(self):
  159. R_foo_bar = So3.exp(self.omega)
  160. Rmat_foo_bar = R_foo_bar.matrix()
  161. point_bar = self.p
  162. p1_foo = R_foo_bar * point_bar
  163. p2_foo = Rmat_foo_bar * point_bar
  164. self.assertEqual(sympy.simplify(p1_foo - p2_foo),
  165. sophus.ZeroVector3())
  166. def test_derivatives(self):
  167. self.assertEqual(sympy.simplify(So3.calc_Dx_exp_x_at_0(self.omega) -
  168. So3.Dx_exp_x_at_0()),
  169. sympy.Matrix.zeros(4, 3))
  170. for i in [0, 1, 2, 3]:
  171. self.assertEqual(sympy.simplify(So3.calc_Dxi_x_matrix(self.a, i) -
  172. So3.Dxi_x_matrix(self.a, i)),
  173. sympy.Matrix.zeros(3, 3))
  174. for i in [0, 1, 2]:
  175. self.assertEqual(sympy.simplify(
  176. So3.Dxi_exp_x_matrix(self.omega, i) -
  177. So3.calc_Dxi_exp_x_matrix(self.omega, i)),
  178. sympy.Matrix.zeros(3, 3))
  179. self.assertEqual(sympy.simplify(
  180. So3.Dxi_exp_x_matrix_at_0(i) -
  181. So3.calc_Dxi_exp_x_matrix_at_0(self.omega, i)),
  182. sympy.Matrix.zeros(3, 3))
  183. def test_codegen(self):
  184. stream = sophus.cse_codegen(So3.calc_Dx_exp_x(self.omega))
  185. filename = "cpp_gencode/So3_Dx_exp_x.cpp"
  186. # set to true to generate codegen files
  187. if False:
  188. file = open(filename, "w")
  189. for line in stream:
  190. file.write(line)
  191. file.close()
  192. else:
  193. file = open(filename, "r")
  194. file_lines = file.readlines()
  195. for i, line in enumerate(stream):
  196. self.assertEqual(line, file_lines[i])
  197. file.close()
  198. stream.close
  199. stream = sophus.cse_codegen(
  200. self.a.calc_Dx_this_mul_exp_x_at_0(self.omega))
  201. filename = "cpp_gencode/So3_Dx_this_mul_exp_x_at_0.cpp"
  202. # set to true to generate codegen files
  203. if False:
  204. file = open(filename, "w")
  205. for line in stream:
  206. file.write(line)
  207. file.close()
  208. else:
  209. file = open(filename, "r")
  210. file_lines = file.readlines()
  211. for i, line in enumerate(stream):
  212. self.assertEqual(line, file_lines[i])
  213. file.close()
  214. stream.close
# Run the unit tests when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()