
[BUG] controlled operation of adjoint doesn't work with jax-jit #6982

Open
1 task done
CatalinaAlbornoz opened this issue Feb 19, 2025 · 0 comments
Labels
bug 🐛 Something isn't working


@CatalinaAlbornoz
Contributor

Expected behavior

It should be possible to use jax.jit on a circuit that contains a controlled operation of an adjoint.

Actual behavior

Tracing the circuit with jax.jit raises a TracerBoolConversionError.

Additional information

This was originally surfaced in Forum thread 7977.

Using qml.ctrl alone or qml.adjoint alone works fine; only the combination of the two fails.

Source code

import pennylane as qml
from jax import numpy as jnp
import jax

# create the device
device = "lightning.qubit"
shots = 100
n_wires = 3
dev = qml.device(device, wires=n_wires, shots=shots)

# create the function
def func_circ(weights, qubit):
    qml.RY(weights, wires=qubit)

# create the adjoint of func_circ
ad = qml.adjoint(func_circ)

# jax-jit the circuit (without jitting it works fine)
@jax.jit
@qml.qnode(device=dev)
def cost_circuit(weights):
    # use a controlled operation on the adjoint
    qml.ctrl(ad, control=0)(weights[0], 2)
    return qml.expval(qml.PauliZ(0))

# run the circuit for some weights
weights = jnp.array([0.,0.,0.])
print(cost_circuit(weights))

Tracebacks

---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
<ipython-input-3-db41d48b13f5> in <cell line: 0>()
     27 # run the circuit for some weights
     28 weights = jnp.array([0.,0.,0.])
---> 29 print(cost_circuit(weights))

    [... skipping hidden 11 frame]

/usr/local/lib/python3.11/dist-packages/pennylane/workflow/qnode.py in __call__(self, *args, **kwargs)
    903         if qml.capture.enabled():
    904             return capture_qnode(self, *args, **kwargs)
--> 905         return self._impl_call(*args, **kwargs)
    906 
    907 

/usr/local/lib/python3.11/dist-packages/pennylane/workflow/qnode.py in _impl_call(self, *args, **kwargs)
    879         self._transform_program.set_classical_component(self, args, kwargs)
    880 
--> 881         res = qml.execute(
    882             (tape,),
    883             device=self.device,

/usr/local/lib/python3.11/dist-packages/pennylane/workflow/execution.py in execute(tapes, device, diff_method, interface, transform_program, inner_transform, config, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, device_vjp, mcm_config, gradient_fn)
    225 
    226     #### Executing the configured setup #####
--> 227     tapes, post_processing = transform_program(tapes)
    228 
    229     if transform_program.is_informative:

/usr/local/lib/python3.11/dist-packages/pennylane/transforms/core/transform_program.py in __call__(self, tapes)
    578                 if argnums is not None:
    579                     tape.trainable_params = argnums[j]
--> 580                 new_tapes, fn = transform(tape, *targs, **tkwargs)
    581                 execution_tapes.extend(new_tapes)
    582 

/usr/local/lib/python3.11/dist-packages/pennylane/gradients/parameter_shift.py in _expand_transform_param_shift(tape, argnum, shifts, gradient_recipes, fallback_fn, f0, broadcast)
    763 ) -> tuple[QuantumScriptBatch, PostprocessingFn]:
    764     """Expand function to be applied before parameter shift."""
--> 765     [new_tape], postprocessing = qml.devices.preprocess.decompose(
    766         tape,
    767         stopping_condition=_param_shift_stopping_condition,

/usr/local/lib/python3.11/dist-packages/pennylane/transforms/core/transform_dispatcher.py in __call__(self, *targs, **tkwargs)
    151 
    152             else:
--> 153                 transformed_tapes, processing_fn = self._transform(obj, *targs, **tkwargs)
    154 
    155             if self.is_informative:

/usr/local/lib/python3.11/dist-packages/pennylane/devices/preprocess.py in decompose(tape, stopping_condition, stopping_condition_shots, skip_initial_state_prep, decomposer, max_expansion, name, error)
    406         prep_op = []
    407 
--> 408     if all(stopping_condition(op) for op in tape.operations[len(prep_op) :]):
    409         return (tape,), null_postprocessing
    410     try:

/usr/local/lib/python3.11/dist-packages/pennylane/devices/preprocess.py in <genexpr>(.0)
    406         prep_op = []
    407 
--> 408     if all(stopping_condition(op) for op in tape.operations[len(prep_op) :]):
    409         return (tape,), null_postprocessing
    410     try:

/usr/local/lib/python3.11/dist-packages/pennylane/gradients/parameter_shift.py in _param_shift_stopping_condition(op)
    744 
    745 def _param_shift_stopping_condition(op) -> bool:
--> 746     if not op.has_decomposition:
    747         # let things without decompositions through without error
    748         # error will happen when calculating parameter shift tapes

/usr/local/lib/python3.11/dist-packages/pennylane/ops/op_math/controlled.py in has_decomposition(self)
    719         if len(self.control_wires) == 1 and hasattr(self.base, "_controlled"):
    720             return True
--> 721         if _is_single_qubit_special_unitary(self.base):
    722             return True
    723         if self.base.has_decomposition:

    [... skipping hidden 1 frame]

/usr/local/lib/python3.11/dist-packages/jax/_src/core.py in error(self, arg)
   1536   if fun is bool:
   1537     def error(self, arg):
-> 1538       raise TracerBoolConversionError(arg)
   1539   elif fun in (hex, oct, operator.index):
   1540     def error(self, arg):

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function cost_circuit at <ipython-input-3-db41d48b13f5>:20 for jit. This concrete value was not available in Python because it depends on the value of the argument weights.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
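The last visible PennyLane frame is Controlled.has_decomposition calling _is_single_qubit_special_unitary(self.base), and its body is hidden. Per the error message, the value being converted to bool depends on weights, so that helper presumably evaluates a check on the base operator (here the adjoint of RY(weights[0])) and hands a traced boolean to Python. The sketch below reproduces that failure mode in plain JAX. The functions are hypothetical stand-ins, not PennyLane's implementation: a determinant check on an RY-style matrix works eagerly, fails under jax.jit when the result is passed to bool(), and stays jit-safe when the JAX boolean is returned instead.

```python
import jax
import jax.numpy as jnp


def ry_matrix(theta):
    # 2x2 RY rotation matrix; theta may be a concrete number or a JAX tracer
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]])


def check_with_python_bool(theta):
    # Hypothetical data-dependent check: bool() needs a concrete value,
    # so this raises TracerBoolConversionError when theta is traced by jit
    return bool(jnp.allclose(jnp.linalg.det(ry_matrix(theta)), 1.0))


def check_traced(theta):
    # jit-safe variant: keep the result as a JAX boolean array
    return jnp.allclose(jnp.linalg.det(ry_matrix(theta)), 1.0)


print(check_with_python_bool(0.3))  # eager call: theta is concrete, works

try:
    jax.jit(check_with_python_bool)(0.3)
except jax.errors.TracerBoolConversionError as err:
    print("jit failed with", type(err).__name__)

print(jax.jit(check_traced)(0.3))
```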

System information

Name: PennyLane
Version: 0.40.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /usr/local/lib/python3.11/dist-packages
Requires: appdirs, autograd, autoray, cachetools, diastatic-malt, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, tomlkit, typing-extensions
Required-by: PennyLane_Lightning

Platform info:           Linux-6.1.85+-x86_64-with-glibc2.35
Python version:          3.11.11
Numpy version:           1.26.4
Scipy version:           1.13.1
Installed devices:
- lightning.qubit (PennyLane_Lightning-0.40.0)
- default.clifford (PennyLane-0.40.0)
- default.gaussian (PennyLane-0.40.0)
- default.mixed (PennyLane-0.40.0)
- default.qubit (PennyLane-0.40.0)
- default.qutrit (PennyLane-0.40.0)
- default.qutrit.mixed (PennyLane-0.40.0)
- default.tensor (PennyLane-0.40.0)
- null.qubit (PennyLane-0.40.0)
- reference.qubit (PennyLane-0.40.0)

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.
@CatalinaAlbornoz CatalinaAlbornoz added the bug 🐛 Something isn't working label Feb 19, 2025
@CatalinaAlbornoz CatalinaAlbornoz changed the title [BUG] controlled oprtation of adjoint doesn't work with jax-jit [BUG] controlled operation of adjoint doesn't work with jax-jit Feb 19, 2025