cse_codegen.py 473 B

1234567891011121314151617
  1. import sympy
  2. import io
  3. def cse_codegen(symbols):
  4. cse_results = sympy.cse(symbols, sympy.numbered_symbols("c"))
  5. output = io.StringIO()
  6. for helper in cse_results[0]:
  7. output.write("Scalar const ")
  8. output.write(sympy.printing.ccode(helper[1], helper[0]))
  9. output.write("\n")
  10. assert len(cse_results[1]) == 1
  11. output.write(sympy.printing.ccode(cse_results[1][0], "result"))
  12. output.write("\n")
  13. output.seek(0)
  14. return output