|
35 | 35 | from discopy.tensor import Diagram
|
36 | 36 | from sympy import default_sort_key, lambdify
|
37 | 37 |
|
| 38 | +from lambeq.training.model import SizedIterable |
38 | 39 | from lambeq.training.quantum_model import QuantumModel
|
39 | 40 |
|
40 | 41 |
|
@@ -100,13 +101,40 @@ def _get_lambda(self, diagram: Diagram) -> Callable[[Any], Any]:
|
100 | 101 | return self.lambdas[diagram]
|
101 | 102 |
|
102 | 103 | def diagram_output(*x):
|
103 |
| - with Tensor.backend('jax'): |
104 |
| - result = diagram.lambdify(*self.symbols)(*x).eval().array |
| 104 | + with Tensor.backend('jax'), tn.DefaultBackend('jax'): |
| 105 | + sub_circuit = self._fast_subs([diagram], x)[0] |
| 106 | + result = tn.contractors.auto(*sub_circuit.to_tn()).tensor |
105 | 107 | return self._normalise_vector(result)
|
106 | 108 |
|
107 | 109 | self.lambdas[diagram] = jit(diagram_output)
|
108 | 110 | return self.lambdas[diagram]
|
109 | 111 |
|
| 112 | + def _fast_subs(self, |
| 113 | + diagrams: list[Diagram], |
| 114 | + weights: SizedIterable) -> list[Diagram]: |
| 115 | + """Substitute weights into a list of parameterised circuit.""" |
| 116 | + parameters = {k: v for k, v in zip(self.symbols, weights)} |
| 117 | + diagrams = pickle.loads(pickle.dumps(diagrams)) # does fast deepcopy |
| 118 | + for diagram in diagrams: |
| 119 | + for b in diagram._boxes: |
| 120 | + if b.free_symbols: |
| 121 | + while hasattr(b, 'controlled'): |
| 122 | + b._free_symbols = set() |
| 123 | + b = b.controlled |
| 124 | + syms, values = [], [] |
| 125 | + for sym in b._free_symbols: |
| 126 | + syms.append(sym) |
| 127 | + try: |
| 128 | + values.append(parameters[sym]) |
| 129 | + except KeyError: |
| 130 | + raise KeyError(f'Unknown symbol {sym!r}.') |
| 131 | + b._data = lambdify(syms, b._data)(*values) |
| 132 | + b.drawing_name = b.name |
| 133 | + b._free_symbols = set() |
| 134 | + if hasattr(b, '_phase'): |
| 135 | + b._phase = b._data |
| 136 | + return diagrams |
| 137 | + |
110 | 138 | def get_diagram_output(self, diagrams: list[Diagram]) -> numpy.ndarray:
|
111 | 139 | """Return the exact prediction for each diagram.
|
112 | 140 |
|
@@ -139,27 +167,7 @@ def get_diagram_output(self, diagrams: list[Diagram]) -> numpy.ndarray:
|
139 | 167 | return numpy.array([diag_f(*self.weights)
|
140 | 168 | for diag_f in lambdified_diagrams])
|
141 | 169 |
|
142 |
| - parameters = {k: v for k, v in zip(self.symbols, self.weights)} |
143 |
| - diagrams = pickle.loads(pickle.dumps(diagrams)) # does fast deepcopy |
144 |
| - for diagram in diagrams: |
145 |
| - for b in diagram._boxes: |
146 |
| - if b.free_symbols: |
147 |
| - while hasattr(b, 'controlled'): |
148 |
| - b._free_symbols = set() |
149 |
| - b = b.controlled |
150 |
| - syms, values = [], [] |
151 |
| - for sym in b._free_symbols: |
152 |
| - syms.append(sym) |
153 |
| - try: |
154 |
| - values.append(parameters[sym]) |
155 |
| - except KeyError: |
156 |
| - raise KeyError(f'Unknown symbol {sym!r}.') |
157 |
| - b._data = lambdify(syms, b._data)(*values) |
158 |
| - b.drawing_name = b.name |
159 |
| - b._free_symbols = set() |
160 |
| - if hasattr(b, '_phase'): |
161 |
| - b._phase = b._data |
162 |
| - |
| 170 | + diagrams = self._fast_subs(diagrams, self.weights) |
163 | 171 | with Tensor.backend('numpy'):
|
164 | 172 | return numpy.array([
|
165 | 173 | self._normalise_vector(tn.contractors.auto(*d.to_tn()).tensor)
|
|
0 commit comments