{autolink-concat}
::::{margin}
:::{card} Square root over arrays with negative values
TR-000
^^^
This notebook investigates how to write a square root function in {mod}sympy
that returns the positive imaginary square root for negative input values. The lambdified version of this 'complex square root' should behave identically on each computational backend.
+++
✅ tensorwaves#284
:::
::::
%pip install -q black==21.5b2 jax==0.2.13 jaxlib==0.1.67 numpy==1.23 sympy==1.8
import inspect
import jax
import jax.numpy as jnp
import numpy as np
import sympy as sp
from black import FileMode, format_str
from IPython.display import display
When using {mod}numpy
as backend, {mod}sympy
lambdifies a {func}~sympy.functions.elementary.miscellaneous.sqrt
to a {obj}numpy.sqrt
:
x = sp.Symbol("x")
sqrt_expr = sp.sqrt(x)
sqrt_expr
np_sqrt = sp.lambdify(x, sqrt_expr, "numpy")
source = inspect.getsource(np_sqrt)
print(source)
def _lambdifygenerated(x): return (sqrt(x))
As expected, if input values for the {obj}numpy.sqrt
are negative, {mod}numpy
raises a {class}RuntimeWarning
and returns NaN
:
sample = np.linspace(-1, 1, 5)
np_sqrt(sample)
<lambdifygenerated-1>:2: RuntimeWarning: invalid value encountered in sqrt return (sqrt(x))
array([ nan, nan, 0. , 0.70710678, 1. ])
If we want {mod}numpy
to return imaginary numbers for negative input values, we can use {class}complex
input data instead (e.g. {doc}numpy.complex64 <numpy:reference/arrays.scalars>
). Negative values are then treated as lying just above the real axis, so that their square root is a positive imaginary number:
complex_sample = sample.astype(np.complex64)
np_sqrt(complex_sample)
array([0. +1.j , 0. +0.70710677j, 0. +0.j , 0.70710677+0.j , 1. +0.j ], dtype=complex64)
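The effect of the input dtype can be seen in isolation in the following minimal sketch. The variable names are illustrative and not part of the notebook; it only assumes a recent {mod}numpy.

```python
import numpy as np

real_input = np.array([-1.0, 0.25])
complex_input = real_input.astype(np.complex128)

# Real dtype: the negative entry becomes NaN (the warning is silenced here).
with np.errstate(invalid="ignore"):
    real_root = np.sqrt(real_input)

# Complex dtype: the negative entry maps onto the positive imaginary axis.
complex_root = np.sqrt(complex_input)

print(real_root)
print(complex_root)
```

Only the dtype differs between the two calls; the values are the same.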
A {func}sympy.sqrt <sympy.functions.elementary.miscellaneous.sqrt>
lambdified to JAX exhibits the same behavior:
jax_sqrt = jax.jit(sp.lambdify(x, sqrt_expr, jnp))
source = inspect.getsource(jax_sqrt)
print(source)
def _lambdifygenerated(x): return (sqrt(x))
jax_sqrt(sample)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
DeviceArray([ nan, nan, 0. , 0.70710677, 1. ], dtype=float32)
jax_sqrt(complex_sample)
DeviceArray([-4.3711388e-08+1.j , -3.0908620e-08+0.70710677j, 0.0000000e+00+0.j , 7.0710677e-01+0.j , 1.0000000e+00+0.j ], dtype=complex64)
There is a problem with this approach though: once the input data is complex, every square root in a larger expression (such as an amplitude model) returns imaginary values for negative arguments, which is not always the desired behavior.
Take for instance the two square roots appearing in {class}~ampform.dynamics.phasespace.PhaseSpaceFactor
--- does the $\sqrt{s}$ also have to be evaluatable for negative $s$?
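To make this concrete, here is a small sketch with a toy expression containing two square roots. The function `toy_factor` and its `threshold` parameter are hypothetical and are not the ampform implementation; they only illustrate that casting the input to complex switches both roots at once.

```python
import numpy as np


def toy_factor(s, threshold=1.0):
    # Hypothetical two-root expression, NOT the ampform phase space factor.
    return np.sqrt(s - threshold) / np.sqrt(s)


s_real = np.array([-1.0, 0.5, 4.0])
s_complex = s_real.astype(np.complex128)

# Real input: NaN wherever either of the two roots sees a negative argument.
with np.errstate(invalid="ignore"):
    real_result = toy_factor(s_real)

# Complex input: both roots follow the complex branch, with no way to
# keep one of them real-valued only.
complex_result = toy_factor(s_complex)

print(real_result)
print(complex_result)
```

With complex input there is no per-root choice: the denominator is also evaluated for negative `s`, whether or not that is physically meaningful.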
NumPy also offers a function that returns complex results for negative values even if the input values are real: {func}numpy.emath.sqrt
:
np.emath.sqrt(-1)
1j
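One property of {func}numpy.emath.sqrt worth keeping in mind is that its output dtype depends on the values in the array, not only on the input dtype. The sketch below (illustrative variable names) shows this for two float arrays:

```python
import numpy as np

non_negative = np.array([0.0, 4.0])
mixed_sign = np.array([-4.0, 4.0])

root_real = np.emath.sqrt(non_negative)
root_complex = np.emath.sqrt(mixed_sign)

print(root_real.dtype)     # float64
print(root_complex.dtype)  # complex128
```

A decision that depends on concrete values like this can only be made once the data is known, which is relevant for backends that trace a function with abstract inputs.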
Unfortunately, the {mod}jax.numpy
API does not interface to {mod}numpy.emath
. It is possible to decorate {func}numpy.emath.sqrt
with {func}jax.jit
, but that only works with static, hashable arguments:
jax_csqrt_error = jax.jit(np.emath.sqrt, backend="cpu")
jax_csqrt_error(-1)
TracerArrayConversionError Traceback (most recent call last)
Input In [13], in <cell line: 2>()
1 jax_csqrt_error = jax.jit(np.emath.sqrt, backend="cpu")
----> 2 jax_csqrt_error(-1)
File <__array_function__ internals>:180, in sqrt(*args, **kwargs)
File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/numpy/lib/scimath.py:247, in sqrt(x)
--> 247 x = _fix_real_lt_zero(x)
248 return nx.sqrt(x)
File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/numpy/lib/scimath.py:134, in _fix_real_lt_zero(x)
--> 134 x = asarray(x)
135 if any(isreal(x) & (x < 0)):
136 x = _tocomplex(x)
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
(https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError)
jax_csqrt = jax.jit(np.emath.sqrt, backend="cpu", static_argnums=0)
jax_csqrt(-1)
DeviceArray(0.+1.j, dtype=complex64)
jax_csqrt(sample)
--------------------------------------------------------------------------- ValueError Traceback (most recent call last) Input In [15], in <cell line: 1>() ----> 1 jax_csqrt(sample) ValueError: Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'numpy.ndarray'>, [-1. -0.5 0. 0.5 1. ]. The error was: TypeError: unhashable type: 'numpy.ndarray'
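The error message points at hashability: static arguments have to be hashable, and a {mod}numpy array is not. The helper `is_hashable` below is a hypothetical name introduced only for this sketch; it checks the property with plain Python.

```python
import numpy as np

sample = np.linspace(-1, 1, 5)


def is_hashable(obj) -> bool:
    """Return True if `hash(obj)` succeeds."""
    try:
        hash(obj)
    except TypeError:
        return False
    return True


print(is_hashable(-1))             # True
print(is_hashable(sample))         # False
print(is_hashable(tuple(sample)))  # True
```

Converting every data sample to a tuple would satisfy the hashing requirement, but it would also treat each distinct sample as a new static value, so this is not a practical route for array input.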
To be able to control which square roots in the complete expression should be evaluatable for negative values, one could use {class}~sympy.functions.elementary.piecewise.Piecewise
:
def complex_sqrt(x: sp.Symbol) -> sp.Expr:
    return sp.Piecewise(
        (sp.sqrt(-x) * sp.I, x < 0),
        (sp.sqrt(x), True),
    )

complex_sqrt(x)

display(
    complex_sqrt(-4),
    complex_sqrt(+4),
)
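As a quick symbolic check, the following sketch substitutes a negative and a positive value into the piecewise expression. It redefines `complex_sqrt` without type hints so that it is self-contained; the names `negative_case` and `positive_case` are illustrative.

```python
import sympy as sp


def complex_sqrt(x):
    # Same piecewise definition as above, repeated to keep the sketch standalone.
    return sp.Piecewise(
        (sp.sqrt(-x) * sp.I, x < 0),
        (sp.sqrt(x), True),
    )


x = sp.Symbol("x")
expr = complex_sqrt(x)

negative_case = expr.subs(x, -4)
positive_case = expr.subs(x, 4)
print(negative_case, positive_case)  # 2*I 2
```

Note that `sp.sqrt(-4)` already evaluates to `2*I` symbolically, so the piecewise form matters mainly for the numerical code that {func}~sympy.utilities.lambdify.lambdify generates from it.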
Be careful though when lambdifying this expression: do not use the __dict__
of the {mod}numpy
module as backend, but use the module itself instead. When using __dict__
, {func}~sympy.utilities.lambdify.lambdify
will return an if-else
statement, which is inefficient and, worse, will result in problems with {doc}JAX <jax:index>
:
:::{warning}
Do not use the module __dict__
for the modules
argument of {func}~sympy.utilities.lambdify.lambdify
.
:::
np_complex_sqrt_no_select = sp.lambdify(x, complex_sqrt(x), np.__dict__)
source = inspect.getsource(np_complex_sqrt_no_select)
print(source)
def _lambdifygenerated(x): return (((1j*sqrt(-x)) if (x < 0) else (sqrt(x))))
np_complex_sqrt_no_select(-1)
1j
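The scalar call succeeds, but the if-else form is already problematic for plain {mod}numpy arrays, independent of JAX. The sketch below writes the same structure as the generated source by hand (the name `if_else_sqrt` is illustrative) and calls it with an array:

```python
import numpy as np


def if_else_sqrt(x):
    # Same structure as the generated source above, written out by hand.
    return 1j * np.sqrt(-x) if x < 0 else np.sqrt(x)


scalar_result = if_else_sqrt(-1)

try:
    if_else_sqrt(np.array([-1.0, 1.0]))
    array_error = None
except ValueError as exc:
    # The truth value of a multi-element array is ambiguous.
    array_error = exc

print(scalar_result)
print(type(array_error).__name__)  # ValueError
```

A Python `if` needs a single boolean, so it cannot branch per element of an array.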
jax_complex_sqrt_no_select = jax.jit(np_complex_sqrt_no_select)
jax_complex_sqrt_no_select(-1)
ConcretizationTypeError Traceback (most recent call last)
Input In [20], in <cell line: 2>()
1 jax_complex_sqrt_no_select = jax.jit(np_complex_sqrt_no_select)
----> 2 jax_complex_sqrt_no_select(-1)
File <lambdifygenerated-3>:2, in _lambdifygenerated(x)
1 def _lambdifygenerated(x):
----> 2 return (((1j*sqrt(-x)) if (x < 0) else (sqrt(x))))
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function.
While tracing the function _lambdifygenerated at <lambdifygenerated-3>:1, transformed by jit., this concrete value was not available in Python because it depends on the value of the arguments to _lambdifygenerated at <lambdifygenerated-3>:1, transformed by jit. at flattened positions [0], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).
(https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError)
When instead using the {mod}numpy
module (or "numpy"
), {func}~sympy.utilities.lambdify.lambdify
correctly lambdifies to {func}numpy.select
to represent the cases.
np_complex_sqrt = sp.lambdify(x, complex_sqrt(x), np)
source = inspect.getsource(np_complex_sqrt)
print(format_str(source.replace("nan)", "nan,)"), mode=FileMode()))
def _lambdifygenerated(x): return select( [less(x, 0), True], [1j * sqrt(-x), sqrt(x)], default=nan, )
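The generated source can be mirrored by hand to see how it behaves on an array. This is a sketch with an illustrative function name, `select_sqrt`; it assumes float input and silences the warning from the branch that is computed but not selected.

```python
import numpy as np


def select_sqrt(x):
    x = np.asarray(x, dtype=float)
    # Both choices are computed for every element before np.select picks one,
    # so the unused branch can produce NaN; the warning is silenced here.
    with np.errstate(invalid="ignore"):
        return np.select(
            [np.less(x, 0), True],
            [1j * np.sqrt(-x), np.sqrt(x)],
            default=np.nan,
        )


roots = select_sqrt([-4.0, 0.0, 9.0])
print(roots)
```

Unlike the if-else form, the selection happens element-wise, and the result is complex-valued because one of the choices is complex.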
Still, JAX does not handle this correctly. First, lambdifying to JAX again results in this if-else
syntax:
jnp_complex_sqrt = sp.lambdify(x, complex_sqrt(x), jnp)
source = inspect.getsource(jnp_complex_sqrt)
print(source)
def _lambdifygenerated(x): return (((1j*sqrt(-x)) if (x < 0) else (sqrt(x))))
But even if we lambdify to {mod}numpy
and wrap the result with {func}jax.jit
, the resulting function fails as soon as it is traced:
jax_complex_sqrt_error = jax.jit(np_complex_sqrt)
source = inspect.getsource(jax_complex_sqrt_error)
print(format_str(source.replace("nan)", "nan,)"), mode=FileMode()))
def _lambdifygenerated(x): return select( [less(x, 0), True], [1j * sqrt(-x), sqrt(x)], default=nan, )
jax_complex_sqrt_error(-1)
TracerArrayConversionError Traceback (most recent call last)
Input In [26], in <cell line: 1>()
----> 1 jax_complex_sqrt_error(-1)
File <lambdifygenerated-4>:2, in _lambdifygenerated(x)
1 def _lambdifygenerated(x):
----> 2 return (select([less(x, 0),True], [1j*sqrt(-x),sqrt(x)], default=nan))
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
(https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError)
The very same function, written purely with {mod}jax.numpy
, works without problems, so this seems to be a SymPy problem:
@jax.jit
def jax_complex_sqrt(x):
    return jnp.select(
        [jnp.less(x, 0), True],
        [1j * jnp.sqrt(-x), jnp.sqrt(x)],
        default=jnp.nan,
    )
jax_complex_sqrt(sample)
DeviceArray([0. +1.j , 0. +0.70710677j, 0. +0.j , 0.70710677+0.j , 1. +0.j ], dtype=complex64)
A solution to this is presented in {ref}report/001:Handle for JAX
.
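As a side note, the same piecewise definition can also be sketched with `jnp.where` instead of `jnp.select`. This is only an illustration and not necessarily the approach taken in the linked report. The function name `complex_sqrt_where` is hypothetical and does not appear elsewhere in this notebook.

```python
import jax
import jax.numpy as jnp


@jax.jit
def complex_sqrt_where(x):
    """Hypothetical helper: positive imaginary root for negative real input."""
    x = jnp.asarray(x)
    # sqrt(|x|) is real and well-defined for every real x, so no NaN appears
    root = jnp.sqrt(jnp.abs(x))
    # Negative inputs get the root on the positive imaginary axis;
    # jnp.where promotes the real branch to a complex dtype.
    return jnp.where(x < 0, 1j * root, root)


values = jnp.array([-1.0, -0.25, 0.0, 4.0])
result = complex_sqrt_where(values)
```

Because both branches are computed from `jnp.abs(x)`, neither of them evaluates the square root of a negative real number, and the function traces under `jax.jit` for scalar as well as array input.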