Jitted Function Save/Load #229
Comments
This works for me to save the TensorFlow jitted function, at least for CPU:

```python
import tensorflow as tf
import tensorcircuit as tc

K = tc.set_backend("tensorflow")

tf_device = "/cpu"
n_wires = 8

def circuit(inputs, weights):
    c = tc.Circuit(n_wires, inputs=inputs)
    c.rz(range(n_wires), theta=weights)
    return c

@tf.function
def qpred(inputs, weights):
    with tf.device(tf_device):
        c = circuit(inputs, weights)
        observables = K.stack(
            [K.real(c.expectation_ps(x=[i])) for i in range(n_wires)]
        )
    return observables

print(qpred(K.ones([2**n_wires]) / K.cast(K.sqrt(2.0**n_wires), "complex64"),
            0.3 * K.real(K.ones([n_wires]))))

tc.keras.save_func(qpred, "tempsave")
f = tc.keras.load_func("tempsave")
print(f(K.ones([2**n_wires]) / K.cast(K.sqrt(2.0**n_wires), "complex64"),
        0.3 * K.real(K.ones([n_wires]))))
```

The model instance cannot be saved via TensorFlow tools, as the model is an instance of a torch model. For jitted jax functions, please refer to https://jax.readthedocs.io/en/latest/export/export.html for IO.

Updated: I have also implemented function wrappers in TensorCircuit-NG to save/load jitted jax functions; see https://tensorcircuit-ng.readthedocs.io/en/latest/advance.html#jitted-function-save-load
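For reference, the jax-native route described in the linked docs looks roughly like the following. This is a minimal sketch assuming a recent jax version where `jax.export` is available (older releases shipped it under `jax.experimental.export` with a slightly different API); the toy function `f` below is a stand-in, not code from this thread:

```python
import jax
import jax.numpy as jnp
from jax import export

def f(x):
    # toy objective standing in for a jitted circuit function
    return jnp.sum(jnp.sin(x) ** 2)

# Trace/lower against an abstract input spec, then serialize to bytes.
exported = export.export(jax.jit(f))(
    jax.ShapeDtypeStruct((8,), jnp.float32)
)
blob = exported.serialize()

# Later (possibly in another process): deserialize and call.
restored = export.deserialize(blob)
print(restored.call(jnp.ones((8,), jnp.float32)))
```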
Yes, I realized that I should directly save the jitted function.

Now, as I continue testing the advanced automatic differentiation features of TensorCircuit, I've encountered a new issue and would appreciate your assistance. Here is the code that triggers the error:

```python
import tensorcircuit as tc
import numpy as np

K = tc.set_backend("tensorflow")

def ansatz(thetas, alpha):
    c = tc.Circuit(2)
    for j in range(2):
        for i in range(2):
            c.rx(i, theta=thetas[j])
            c.ry(i, theta=alpha[j])
        for i in range(2 - 1):
            c.cnot(i, i + 1)
    return c

def f(thetas, alpha):
    c = ansatz(thetas, alpha)
    observables = K.stack([K.real(c.expectation_ps(z=[i])) for i in range(2)])
    return K.mean(observables, axis=0)

f_vmap = K.vmap(f, vectorized_argnums=0)
inputs = K.implicit_randn([3, 2])
weights = K.implicit_randn([2])

g1 = K.grad(f_vmap)
grad1 = g1(inputs, weights).numpy()
g2 = K.grad(g1)
grad2 = g2(inputs, weights).numpy()

jac_fun = K.vmap(K.jacfwd(f), vectorized_argnums=0)
jac1 = jac_fun(inputs, weights).numpy()
print(np.isclose(jac1, grad1).all())

hess_f = K.vmap(K.hessian(f), vectorized_argnums=0)
hess_f(inputs, weights)  # Error occurs
```

When running this code, I encounter an error. It appears that the issue arises when computing the second-order derivative using K.hessian.

Additionally, I have a question regarding the logic of my code. I am trying to incorporate higher-order derivatives into the loss function for soft constraints. Could you kindly confirm whether the approach in my code will correctly compute the higher-order derivatives of the quantum neural network output with respect to the input (not the trainable parameters)?

Thank you very much for your time and assistance. I truly appreciate any insights or suggestions you may have.
TensorCircuit-NG head has fixed the error above; see https://github.com/tensorcircuit/tensorcircuit-ng/blob/master/examples/nested_vmap_grad.py for the example. But note that when vmap is outside a grad-like function on the tensorflow backend, the numerical results will be incorrect due to a long-standing bug: tensorflow/tensorflow#52148
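To make the ordering issue concrete, here is a minimal pure-jax sketch (a toy objective standing in for the circuit, not tensorcircuit code). On jax the two compositions agree; the analogous vmap-outside-grad composition on the tensorflow backend is the one affected by the bug above:

```python
import jax
import jax.numpy as jnp

def f(theta, alpha):
    # scalar toy objective standing in for the circuit expectation
    return jnp.mean(jnp.cos(theta) * jnp.sin(alpha))

thetas = jnp.ones((3, 2))
alpha = jnp.ones((2,))

# vmap outside grad: per-sample gradients w.r.t. each batched input row
per_sample = jax.vmap(jax.grad(f), in_axes=(0, None))(thetas, alpha)

# grad outside vmap: gradient of the batch-summed objective
batched = jax.grad(
    lambda t, a: jnp.sum(jax.vmap(f, in_axes=(0, None))(t, a))
)(thetas, alpha)

# Rows are independent, so the two results coincide on the jax backend.
print(jnp.allclose(per_sample, batched))  # True
```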
Thank you for your patience and assistance. I now understand that, to obtain the correct gradient, vmap should not be placed outside a grad-like function on the tensorflow backend.

Specifically, here is the code I am using:

```python
import tensorcircuit as tc

for backend in ["tensorflow", "jax"]:
    with tc.runtime_backend(backend) as K:
        print(f"\n{backend}")
        L = 2
        inputs = K.cast(K.ones([3, 2]), tc.rdtypestr)
        weights = K.cast(K.ones([2]), tc.rdtypestr)

        def ansatz(thetas, alpha):
            c = tc.Circuit(L)
            for j in range(2):
                for i in range(L):
                    c.rx(i, theta=thetas[j])
                    c.ry(i, theta=alpha[j])
                for i in range(L - 1):
                    c.cnot(i, i + 1)
            return c

        def f(thetas, alpha):
            c = ansatz(thetas, alpha)
            observables = K.stack([K.real(c.expectation_ps(z=[i])) for i in range(L)])
            return K.mean(observables)

        print("vmap", K.vmap(f)(inputs, weights))
        print("jvp", K.jacfwd(f)(inputs[0], weights))
        print("hess", K.diagonal(K.hessian(f)(inputs[0], weights)))
        print("grad1_0", K.grad(f)(inputs[0], weights))
        print("grad1_all", K.grad(K.vmap(f))(inputs, weights))
        print("grad2_0", K.grad(K.grad(f))(inputs[0], weights))
        print("grad2_all", K.grad(K.grad(K.vmap(f)))(inputs, weights))
```

I expected the following:

However, the actual results I obtained were as follows:

TensorFlow

JAX

This does not match the results I expected. Is there a reasonable explanation for this? I would like to know which method is correct for calculating higher-order derivatives of the output with respect to the components of inputs.

Thank you very much for your help!
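As a backend-agnostic sanity check on which composition gives the right numbers, comparing against finite differences may help. This is a hedged sketch (the helper below is hypothetical, not part of tensorcircuit); it approximates one diagonal entry of the Hessian of the scalar function f above with a central second difference:

```python
import numpy as np

def finite_diff_second(f, x, weights, k, eps=1e-3):
    # Central second difference of scalar f along component k of x:
    # f''(x)_kk ≈ (f(x + eps*e_k) - 2 f(x) + f(x - eps*e_k)) / eps**2
    e = np.zeros_like(x)
    e[k] = eps
    return (f(x + e, weights) - 2.0 * f(x, weights) + f(x - e, weights)) / eps**2

# Usage sketch with the script above (convert backend tensors to numpy first):
# x0 = np.asarray(inputs[0]); w0 = np.asarray(weights)
# fd = finite_diff_second(lambda t, a: float(f(t, a)), x0, w0, k=0)
# Compare fd against K.diagonal(K.hessian(f)(inputs[0], weights))[0].
```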
Issue Description
I am facing an issue when trying to save and load JIT-compiled functions using tensorcircuit.keras.save_func() and tensorcircuit.keras.load_func(). Specifically, I am trying to save and load the qpred (or qlayer) function in my hybrid model, but I encounter the following error when trying to load the function:

Here is the code that I am working with:

What I have tried: using tensorcircuit.keras.save_func() and tensorcircuit.keras.load_func() to save and load the function qpred or qlayer, but it results in the above error.

I am wondering if there is a different approach to saving/loading the JIT-compiled function, or if there is a potential issue with the way TensorCircuit handles saved functions in this context.
Would you be able to provide guidance or suggest an alternative solution for saving/loading the function, especially one that involves JIT compilation?
Thank you very much for your time and assistance. I appreciate any help you can provide!
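For context, one plain-TensorFlow fallback for persisting a tf.function is tf.saved_model. This is a sketch under assumptions (the wrapper class name and the shapes/dtypes are taken from the qpred example in the first comment, not from my own elided code), not the tensorcircuit API:

```python
import tensorflow as tf

class QPredModule(tf.Module):  # hypothetical wrapper name
    @tf.function(
        input_signature=[
            tf.TensorSpec([2**8], tf.complex64),
            tf.TensorSpec([8], tf.float32),
        ]
    )
    def __call__(self, inputs, weights):
        # reuse qpred as defined in the first comment's snippet
        return qpred(inputs, weights)

tf.saved_model.save(QPredModule(), "tempsave_tf")
restored = tf.saved_model.load("tempsave_tf")
# restored(inputs, weights) now runs the saved concrete function
# without re-tracing the original Python code.
```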
Environment Context
OS info: Linux-5.4.0-150-generic-x86_64-with-glibc2.27
Python version: 3.10.14
Numpy version: 1.26.4
Scipy version: 1.12.0
Pandas version: 2.2.2
TensorNetwork version: 0.5.1
Cotengra version: 0.6.2
TensorFlow version: 2.18.0
TensorFlow GPU: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:2', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:3', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:4', device_type='GPU')]
TensorFlow CUDA infos: {'cpu_compiler': '/usr/lib/llvm-18/bin/clang', 'cuda_compute_capabilities': ['sm_60', 'sm_70', 'sm_80', 'sm_89', 'compute_90'], 'cuda_version': '12.5.1', 'cudnn_version': '9', 'is_cuda_build': True, 'is_rocm_build': False, 'is_tensorrt_build': False}
Jax version: 0.4.23
Jax GPU: [cuda(id=0), cuda(id=1), cuda(id=2), cuda(id=3), cuda(id=4)]
JaxLib version: 0.4.23
PyTorch version: 2.5.1+cu124
PyTorch GPU support: True
PyTorch GPUs: [<torch.cuda.device object at 0x7fec328bd8d0>, <torch.cuda.device object at 0x7fec328bd900>, <torch.cuda.device object at 0x7fec328bd8a0>, <torch.cuda.device object at 0x7fec328bdc30>, <torch.cuda.device object at 0x7fec328bdc90>]
Pytorch cuda version: 12.4
Cupy is not installed
Qiskit version: 1.3.1
Cirq version: 1.4.1
TensorCircuit version: 0.12.0