quaternion.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. """ run with: python3 -m sophus.quaternion """
  2. import sophus
  3. import sympy
  4. import sys
  5. import unittest
  6. class Quaternion:
  7. """ Quaternion class """
  8. def __init__(self, real, vec):
  9. """ Quaternion consists of a real scalar, and an imaginary 3-vector """
  10. assert isinstance(vec, sympy.Matrix)
  11. assert vec.shape == (3, 1), vec.shape
  12. self.real = real
  13. self.vec = vec
  14. def __mul__(self, right):
  15. """ quaternion multiplication """
  16. return Quaternion(self[3] * right[3] - self.vec.dot(right.vec),
  17. self[3] * right.vec + right[3] * self.vec +
  18. self.vec.cross(right.vec))
  19. def __add__(self, right):
  20. """ quaternion multiplication """
  21. return Quaternion(self[3] + right[3], self.vec + right.vec)
  22. def __neg__(self):
  23. return Quaternion(-self[3], -self.vec)
  24. def __truediv__(self, scalar):
  25. """ scalar division """
  26. return Quaternion(self.real / scalar, self.vec / scalar)
  27. def __repr__(self):
  28. return "( " + repr(self[3]) + " + " + repr(self.vec) + "i )"
  29. def __getitem__(self, key):
  30. """ We use the following convention [vec0, vec1, vec2, real] """
  31. assert (key >= 0 and key < 4)
  32. if key == 3:
  33. return self.real
  34. else:
  35. return self.vec[key]
  36. def squared_norm(self):
  37. """ squared norm when considering the quaternion as 4-tuple """
  38. return sophus.squared_norm(self.vec) + self.real**2
  39. def conj(self):
  40. """ quaternion conjugate """
  41. return Quaternion(self.real, -self.vec)
  42. def inv(self):
  43. """ quaternion inverse """
  44. return self.conj() / self.squared_norm()
  45. @staticmethod
  46. def identity():
  47. return Quaternion(1, sophus.Vector3(0, 0, 0))
  48. @staticmethod
  49. def zero():
  50. return Quaternion(0, sophus.Vector3(0, 0, 0))
  51. def subs(self, x, y):
  52. return Quaternion(self.real.subs(x, y), self.vec.subs(x, y))
  53. def simplify(self):
  54. v = sympy.simplify(self.vec)
  55. return Quaternion(sympy.simplify(self.real),
  56. sophus.Vector3(v[0], v[1], v[2]))
  57. def __eq__(self, other):
  58. if isinstance(self, other.__class__):
  59. return self.real == other.real and self.vec == other.vec
  60. return False
  61. @staticmethod
  62. def Da_a_mul_b(a, b):
  63. """ derivatice of quaternion muliplication wrt left multiplier a """
  64. v0 = b.vec[0]
  65. v1 = b.vec[1]
  66. v2 = b.vec[2]
  67. y = b.real
  68. return sympy.Matrix([[y, v2, -v1, v0],
  69. [-v2, y, v0, v1],
  70. [v1, -v0, y, v2],
  71. [-v0, -v1, -v2, y]])
  72. @staticmethod
  73. def Db_a_mul_b(a, b):
  74. """ derivatice of quaternion muliplication wrt right multiplicand b """
  75. u0 = a.vec[0]
  76. u1 = a.vec[1]
  77. u2 = a.vec[2]
  78. x = a.real
  79. return sympy.Matrix([[x, -u2, u1, u0],
  80. [u2, x, -u0, u1],
  81. [-u1, u0, x, u2],
  82. [-u0, -u1, -u2, x]])
  83. class TestQuaternion(unittest.TestCase):
  84. def setUp(self):
  85. x, u0, u1, u2 = sympy.symbols('x u0 u1 u2', real=True)
  86. y, v0, v1, v2 = sympy.symbols('y v0 v1 v2', real=True)
  87. u = sophus.Vector3(u0, u1, u2)
  88. v = sophus.Vector3(v0, v1, v2)
  89. self.a = Quaternion(x, u)
  90. self.b = Quaternion(y, v)
  91. def test_muliplications(self):
  92. product = self.a * self.a.inv()
  93. self.assertEqual(product.simplify(),
  94. Quaternion.identity())
  95. product = self.a.inv() * self.a
  96. self.assertEqual(product.simplify(),
  97. Quaternion.identity())
  98. def test_derivatives(self):
  99. d = sympy.Matrix(4, 4, lambda r, c: sympy.diff(
  100. (self.a * self.b)[r], self.a[c]))
  101. self.assertEqual(d,
  102. Quaternion.Da_a_mul_b(self.a, self.b))
  103. d = sympy.Matrix(4, 4, lambda r, c: sympy.diff(
  104. (self.a * self.b)[r], self.b[c]))
  105. self.assertEqual(d,
  106. Quaternion.Db_a_mul_b(self.a, self.b))
  107. if __name__ == '__main__':
  108. unittest.main()
  109. print('hello')