complex.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. import sophus
  2. import sympy
  3. import sys
  4. import unittest
  5. class Complex:
  6. """ Complex class """
  7. def __init__(self, real, imag):
  8. self.real = real
  9. self.imag = imag
  10. def __mul__(self, right):
  11. """ complex multiplication """
  12. return Complex(self.real * right.real - self.imag * right.imag,
  13. self.imag * right.real + self.real * right.imag)
  14. def __add__(self, right):
  15. return Complex(elf.real + right.real, self.imag + right.imag)
  16. def __neg__(self):
  17. return Complex(-self.real, -self.image)
  18. def __truediv__(self, scalar):
  19. """ scalar division """
  20. return Complex(self.real / scalar, self.imag / scalar)
  21. def __repr__(self):
  22. return "( " + repr(self.real) + " + " + repr(self.imag) + "i )"
  23. def __getitem__(self, key):
  24. """ We use the following convention [real, imag] """
  25. if key == 0:
  26. return self.real
  27. else:
  28. return self.imag
  29. def squared_norm(self):
  30. """ squared norm when considering the complex number as tuple """
  31. return self.real**2 + self.imag**2
  32. def conj(self):
  33. """ complex conjugate """
  34. return Complex(self.real, -self.imag)
  35. def inv(self):
  36. """ complex inverse """
  37. return self.conj() / self.squared_norm()
  38. @staticmethod
  39. def identity():
  40. return Complex(1, 0)
  41. @staticmethod
  42. def zero():
  43. return Complex(0, 0)
  44. def __eq__(self, other):
  45. if isinstance(self, other.__class__):
  46. return self.real == other.real and self.imag == other.imag
  47. return False
  48. def subs(self, x, y):
  49. return Complex(self.real.subs(x, y), self.imag.subs(x, y))
  50. def simplify(self):
  51. return Complex(sympy.simplify(self.real),
  52. sympy.simplify(self.imag))
  53. @staticmethod
  54. def Da_a_mul_b(a, b):
  55. """ derivatice of complex muliplication wrt left multiplier a """
  56. return sympy.Matrix([[b.real, -b.imag],
  57. [b.imag, b.real]])
  58. @staticmethod
  59. def Db_a_mul_b(a, b):
  60. """ derivatice of complex muliplication wrt right multiplicand b """
  61. return sympy.Matrix([[a.real, -a.imag],
  62. [a.imag, a.real]])
  63. class TestComplex(unittest.TestCase):
  64. def setUp(self):
  65. x, y = sympy.symbols('x y', real=True)
  66. u, v = sympy.symbols('u v', real=True)
  67. self.a = Complex(x, y)
  68. self.b = Complex(u, v)
  69. def test_muliplications(self):
  70. product = self.a * self.a.inv()
  71. self.assertEqual(product.simplify(),
  72. Complex.identity())
  73. product = self.a.inv() * self.a
  74. self.assertEqual(product.simplify(),
  75. Complex.identity())
  76. def test_derivatives(self):
  77. d = sympy.Matrix(2, 2, lambda r, c: sympy.diff(
  78. (self.a * self.b)[r], self.a[c]))
  79. self.assertEqual(d,
  80. Complex.Da_a_mul_b(self.a, self.b))
  81. d = sympy.Matrix(2, 2, lambda r, c: sympy.diff(
  82. (self.a * self.b)[r], self.b[c]))
  83. self.assertEqual(d,
  84. Complex.Db_a_mul_b(self.a, self.b))
  85. if __name__ == '__main__':
  86. unittest.main()
  87. print('hello')