
Jitted Function Save/Load #229

Open
xiazhuo opened this issue Jan 3, 2025 · 5 comments
Labels
bug Something isn't working

Comments

@xiazhuo

xiazhuo commented Jan 3, 2025

Issue Description

I am facing an issue when trying to save and load JIT-compiled functions using tensorcircuit.keras.save_func() and tensorcircuit.keras.load_func(). Specifically, I am trying to save and load the qpred (or qlayer) function in my hybrid model, but I encounter the following error when trying to load the function:

  File "/home/.miniconda3/envs/qml/lib/python3.10/site-packages/tensorcircuit/keras.py", line 284, in wrapper  *
    return m.f(*args, **kws)
AttributeError: '_UserObject' object has no attribute 'f'

Here is the code that I am working with:

import torch
import tensorflow as tf
import tensorcircuit as tc

# `circuit(inputs, weights, trunk_size)` builds the parameterized circuit and
# is defined elsewhere in my code.

class HybridModel(torch.nn.Module):
    def __init__(self, trunk_size, n_layers=2, n_hidden_layers=4, n_wires=2):
        super().__init__()
        K = tc.set_backend("tensorflow")
        tf_device = "/gpu"

        @tf.function
        def qpred(inputs, weights):
            with tf.device(tf_device):
                c = circuit(inputs, weights, trunk_size)
                observables = K.stack([K.real(c.expectation_ps(z=[i]))
                                       for i in range(n_wires)])
                return observables

        self.qpred = qpred
        self.qlayer = tc.TorchLayer(
            self.qpred, weights_shape=[2*n_layers, n_hidden_layers, n_wires, 2], use_jit=True, enable_dlpack=True)
        self.clayer = torch.nn.Linear(n_wires, 1)

    def forward(self, inputs):
        outputs = self.qlayer(inputs)
        outputs = torch.mean(outputs, axis=1)
        return outputs

What I have tried:

  • I have attempted to use tensorcircuit.keras.save_func() and tensorcircuit.keras.load_func() to save and load the function qpred or qlayer, but it results in the above error.

I am wondering if there is a different approach to saving/loading the JIT-compiled function, or if there is a potential issue with the way TensorCircuit handles saved functions in this context.

Would you be able to provide guidance or suggest an alternative solution for saving/loading the function, especially one that involves JIT compilation?

Thank you very much for your time and assistance. I appreciate any help you can provide!

Environment Context

OS info: Linux-5.4.0-150-generic-x86_64-with-glibc2.27
Python version: 3.10.14
Numpy version: 1.26.4
Scipy version: 1.12.0
Pandas version: 2.2.2
TensorNetwork version: 0.5.1
Cotengra version: 0.6.2
TensorFlow version: 2.18.0
TensorFlow GPU: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:2', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:3', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:4', device_type='GPU')]
TensorFlow CUDA infos: {'cpu_compiler': '/usr/lib/llvm-18/bin/clang', 'cuda_compute_capabilities': ['sm_60', 'sm_70', 'sm_80', 'sm_89', 'compute_90'], 'cuda_version': '12.5.1', 'cudnn_version': '9', 'is_cuda_build': True, 'is_rocm_build': False, 'is_tensorrt_build': False}
Jax version: 0.4.23
Jax GPU: [cuda(id=0), cuda(id=1), cuda(id=2), cuda(id=3), cuda(id=4)]
JaxLib version: 0.4.23
PyTorch version: 2.5.1+cu124
PyTorch GPU support: True
PyTorch GPUs: [<torch.cuda.device object at 0x7fec328bd8d0>, <torch.cuda.device object at 0x7fec328bd900>, <torch.cuda.device object at 0x7fec328bd8a0>, <torch.cuda.device object at 0x7fec328bdc30>, <torch.cuda.device object at 0x7fec328bdc90>]
Pytorch cuda version: 12.4
Cupy is not installed
Qiskit version: 1.3.1
Cirq version: 1.4.1
TensorCircuit version 0.12.0

@xiazhuo xiazhuo added the bug Something isn't working label Jan 3, 2025
@refraction-ray
Contributor

refraction-ray commented Jan 5, 2025

This works for me to save the TensorFlow jitted function, at least on CPU:

import tensorflow as tf
import tensorcircuit as tc

K = tc.set_backend("tensorflow")

tf_device = "/cpu"
n_wires = 8

def circuit(inputs, weights):
    c = tc.Circuit(n_wires, inputs=inputs)
    c.rz(range(n_wires), theta=weights)
    return c

@tf.function
def qpred(inputs, weights):
    with tf.device(tf_device):
        c = circuit(inputs, weights)
        observables = K.stack([K.real(c.expectation_ps(x=[i]))
                                for i in range(n_wires)])
        return observables

print(qpred(K.ones([2**n_wires])/K.cast(K.sqrt(2.0**n_wires), "complex64"), 0.3*K.real(K.ones([n_wires]))))

tc.keras.save_func(qpred, "tempsave")
f = tc.keras.load_func("tempsave")
print(f(K.ones([2**n_wires])/K.cast(K.sqrt(2.0**n_wires), "complex64"), 0.3*K.real(K.ones([n_wires]))))

The model instance itself cannot be saved via TensorFlow tools, as the model is an instance of a torch module.
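
A rough sketch of how the torch side could be persisted separately (this is an assumed workflow, not a tested recipe; the constructor arguments below are placeholders):

import torch

# Save the torch-side parameters (assuming the TorchLayer weights are
# registered as torch parameters) with torch's own tools; the jitted quantum
# function is saved separately via tc.keras.save_func as above.
model = HybridModel(trunk_size=4)                  # placeholder arguments
torch.save(model.state_dict(), "hybrid_model.pt")

# ... later, rebuild the module and restore its parameters ...
model2 = HybridModel(trunk_size=4)
model2.load_state_dict(torch.load("hybrid_model.pt"))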

For jax jitted functions, please refer to https://jax.readthedocs.io/en/latest/export/export.html for IO.

Updated: I have also implemented function wrappers in TensorCircuit-NG to save/load jitted jax functions: please see https://tensorcircuit-ng.readthedocs.io/en/latest/advance.html#jitted-function-save-load
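
A minimal sketch of the jax route (assuming a recent jax where the jax.export module is available; older versions expose it under jax.experimental.export, and the function and shapes below are only placeholders):

import jax
import jax.numpy as jnp
from jax import export

def g(x):
    return jnp.sum(jnp.sin(x)) * 2.0

# Trace/export the jitted function for a concrete input shape and dtype,
# then serialize it to bytes that can be written to disk.
exported = export.export(jax.jit(g))(jax.ShapeDtypeStruct((8,), jnp.float32))
blob = exported.serialize()

# ... store `blob`, then later deserialize and call the restored function ...
rehydrated = export.deserialize(blob)
print(rehydrated.call(jnp.ones((8,), jnp.float32)))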

@xiazhuo
Author

xiazhuo commented Jan 19, 2025

Yes, I realized that I should directly save qpred instead of the entire model. Additionally, I encountered an issue when using the TorchLayer interface to wrap qpred, which also led to problems with the load_func. It turns out I should be using TensorFlow’s interface instead. Thank you so much for your enthusiastic help — your guidance has been incredibly valuable!

Now, as I continue testing the advanced automatic differentiation features of TensorCircuit, I’ve encountered a new issue and would appreciate your assistance.

Here is the code that triggers the error:

import tensorcircuit as tc 
import numpy as np

K = tc.set_backend("tensorflow")

def ansatz(thetas, alpha):
    c = tc.Circuit(2)
    for j in range(2):
        for i in range(2):
            c.rx(i, theta=thetas[j])
            c.ry(i, theta=alpha[j])
        for i in range(2 - 1):
            c.cnot(i, i + 1)
    return c

def f(thetas, alpha):
    c = ansatz(thetas, alpha)
    observables = K.stack([K.real(c.expectation_ps(z=[i]))
                           for i in range(2)])
    return K.mean(observables, axis=0)

f_vmap = K.vmap(f, vectorized_argnums=0)

inputs = K.implicit_randn([3,2])
weights = K.implicit_randn([2])

g1 = K.grad(f_vmap)
grad1 = g1(inputs, weights).numpy()
g2 = K.grad(g1)
grad2 = g2(inputs, weights).numpy()

jac_fun = K.vmap(K.jacfwd(f), vectorized_argnums=0)
jac1 = jac_fun(inputs, weights).numpy()
print(np.isclose(jac1, grad1).all())

hess_f = K.vmap(K.hessian(f), vectorized_argnums=0)
hess_f(inputs, weights)   # Error occurs

When running this code, I encounter the following error:

ERROR:tensorflow:Got error while pfor was converting op name: "gradient_tape/gradient_tape/Einsum"
op: "Einsum"
input: "tangents"
attr {
  key: "equation"
  value {
    s: "abc->bcaa"
  }
}
attr {
  key: "T"
  value {
    type: DT_COMPLEX64
  }
}
attr {
  key: "N"
  value {
    i: 1
  }
}
 with inputs (<tf.Tensor 'tangents:0' shape=(2, 2, 2) dtype=complex64>,)
, converted inputs [WrappedTensor(t=<tf.Tensor 'args_2:0' shape=(2, 2, 2, 2) dtype=complex64>, is_stacked=True, is_sparse_stacked=False)]
Here are the pfor conversion stack traces: Output subscripts contain a label appearing more than once: dabc->dbcaa
ERROR:tensorflow:name: "gradient_tape/gradient_tape/Einsum"
op: "Einsum"
input: "tangents"
attr {
  key: "equation"
  value {
    s: "abc->bcaa"
  }
}
attr {
  key: "T"
  value {
    type: DT_COMPLEX64
  }
}
attr {
  key: "N"
  value {
    i: 1
  }
}
ValueError: Output subscripts contain a label appearing more than once: dabc->dbcaa

It appears that the issue arises when computing the second-order derivative using K.hessian. From the error message, it seems like there's an issue with the way subscripts are being handled in the Einsum operation.

Additionally, I have a question regarding the logic of my code. I am trying to incorporate higher-order derivatives into the loss function for soft constraints. Could you kindly confirm whether the approach in my code will correctly compute the high-order derivatives of the quantum neural network output with respect to the input (not the trainable parameters)?

Thank you very much for your time and assistance. I truly appreciate any insights or suggestions you may have.

@refraction-ray
Contributor

TensorCircuit-NG head has fixed the error above; see https://github.com/tensorcircuit/tensorcircuit-ng/blob/master/examples/nested_vmap_grad.py for the example. But note that when vmap is outside a grad-like function on the tensorflow backend, the numerical results will be incorrect due to a long-standing bug, tensorflow/tensorflow#52148.
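
In code, the two orderings being contrasted are roughly the following (a sketch using the f from the snippets in this thread; the caveat above concerns the first form on the tensorflow backend):

# vmap outside grad: per-sample gradients; on the tensorflow backend this goes
# through the vectorized_map/pfor path affected by tensorflow/tensorflow#52148.
per_sample_grad = K.vmap(K.grad(f), vectorized_argnums=0)

# grad outside vmap: differentiate the batched (and, on tensorflow, implicitly
# summed) output instead.
grad_of_batched = K.grad(K.vmap(f, vectorized_argnums=0))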

@xiazhuo
Author

xiazhuo commented Jan 21, 2025

Thank you for your patience and assistance. I now understand that to obtain the correct gradient, vmap should be placed inside the grad function. However, while testing the code you provided, I encountered some puzzling behaviors that I am struggling to understand. I would greatly appreciate any explanation you could provide.

Specifically, here is the code I am using:

import tensorcircuit as tc

for backend in ["tensorflow", "jax"]:
    with tc.runtime_backend(backend) as K:
        print(f"\n{backend}")
        L = 2
        inputs = K.cast(K.ones([3, 2]), tc.rdtypestr)
        weights = K.cast(K.ones([2]), tc.rdtypestr)

        def ansatz(thetas, alpha):
            c = tc.Circuit(L)
            for j in range(2):
                for i in range(L):
                    c.rx(i, theta=thetas[j])
                    c.ry(i, theta=alpha[j])
                for i in range(L - 1):
                    c.cnot(i, i + 1)
            return c

        def f(thetas, alpha):
            c = ansatz(thetas, alpha)
            observables = K.stack([K.real(c.expectation_ps(z=[i])) for i in range(L)])
            return K.mean(observables)

        print("vmap", K.vmap(f)(inputs, weights))
        print("jvp", K.jacfwd(f)(inputs[0], weights))
        print("hess", K.diagonal(K.hessian(f)(inputs[0], weights)))
        print("grad1_0", K.grad(f)(inputs[0], weights))
        print("grad1_all", K.grad(K.vmap(f))(inputs, weights))
        print("grad2_0", K.grad(K.grad(f))(inputs[0], weights))
        print("grad2_all", K.grad(K.grad(K.vmap(f)))(inputs, weights))

I expected the following:

  1. The output of jvp should match grad1_0, representing the first-order derivatives of the output with respect to the two components of inputs.
  2. The output of hess should match grad2_0, representing the second-order derivatives of the output with respect to the two components of inputs.
  3. grad1_0 should be the first row of grad1_all.
  4. grad2_0 should be the first row of grad2_all.
  5. These results should be consistent across backends, meaning the results with TensorFlow and JAX should be identical.

However, the actual results I obtained were as follows:

TensorFlow

vmap tf.Tensor([0.2257461 0.2257461 0.2257461], shape=(3,), dtype=float32)
jvp tf.Tensor([ 0.26698774 -0.08074658], shape=(2,), dtype=float32)
hess tf.Tensor([-0.49542376  0.14224827], shape=(2,), dtype=float32)
grad1_0 tf.Tensor([ 0.26698777 -0.08074659], shape=(2,), dtype=float32)
grad1_all tf.Tensor(
[[ 0.26698768 -0.08074658]
 [ 0.26698768 -0.08074658]
 [ 0.26698768 -0.08074658]], shape=(3, 2), dtype=float32)
grad2_0 tf.Tensor([-0.42846823  0.20920378], shape=(2,), dtype=float32)
grad2_all tf.Tensor(
[[-0.42846814  0.20920385]
 [-0.42846814  0.20920385]
 [-0.42846814  0.20920385]], shape=(3, 2), dtype=float32)

JAX

vmap [0.22536246 0.22536246 0.22536246]
jvp [ 0.26677302 -0.08061091]
hess [-0.49530387  0.14196607]
grad1_0 [ 0.26679915 -0.08059981]
TypeError: 
---> 27 print("grad1_all", K.grad(K.vmap(f))(inputs, weights[0]))
Gradient only defined for scalar-output functions. Output had shape: (3,).

This does not match the results I expected. Is there a reasonable explanation for this? I would like to know which method is correct for calculating higher-order derivatives of the output with respect to the components of inputs.

Thank you very much for your help!

@refraction-ray
Contributor

K.grad(K.grad(f)): the outer gradient is taken of the objective $(\partial_{\text{inputs}_0}+\partial_{\text{inputs}_1})f$ (TF behavior; jax does not assume the implicit sum over outputs, which is why it complains about differentiating a non-scalar output). The two values returned by K.grad(K.grad(f)) therefore correspond to the row sums of the Hessian matrix (K.hessian).
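
A quick numerical check of this statement (a sketch on the tensorflow backend, reusing f, inputs and weights from the script above and assuming the backend exposes K.sum):

# K.grad(K.grad(f)) should reproduce the row sums of the full Hessian
# with respect to the first argument.
hess = K.hessian(f)(inputs[0], weights)        # shape (2, 2)
row_sums = K.sum(hess, axis=-1)                # H[i, 0] + H[i, 1]
print(K.grad(K.grad(f))(inputs[0], weights))   # expected to match row_sums
print(row_sums)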
