Skip to content

Commit 70a1fe8

Browse files
ianyfanThomas Hoffmannnikhilkhatri
committed
Release version 0.2.8
Co-authored-by: Thomas Hoffmann <thomas.hoffmann@cambridgequantum.com> Co-authored-by: Nikhil Khatri <nikhil.khatri@quantinuum.com>
1 parent c4e361a commit 70a1fe8

14 files changed

+70
-40
lines changed

.github/workflows/build_test.yml

+8-19
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ jobs:
2121
outputs:
2222
error-check: ${{ steps.error-check.conclusion }}
2323
steps:
24-
- uses: actions/checkout@v2
24+
- uses: actions/checkout@v3
2525
- name: Setup Python ${{ matrix.python-version }}
26-
uses: actions/setup-python@v2
26+
uses: actions/setup-python@v4
2727
with:
2828
python-version: ${{ matrix.python-version }}
2929
- name: Install linter
@@ -49,25 +49,14 @@ jobs:
4949
matrix:
5050
python-version: [ 3.8, 3.9, "3.10" ]
5151
steps:
52-
- uses: actions/checkout@v2
52+
- uses: actions/checkout@v3
5353
- name: Setup Python ${{ matrix.python-version }}
54-
uses: actions/setup-python@v2
54+
uses: actions/setup-python@v4
5555
with:
5656
python-version: ${{ matrix.python-version }}
57-
- name: Locate pip cache
58-
id: loc-pip-cache
59-
run: echo "::set-output name=dir::$(pip cache dir)"
60-
- name: Restore pip dependencies from cache
61-
uses: actions/cache@v2
62-
with:
63-
path: ${{ steps.loc-pip-cache.outputs.dir }}
64-
key: build_and_test-${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('setup.cfg') }}
65-
restore-keys: |
66-
build_and_test-${{ runner.os }}-pip-${{ matrix.python-version }}-
67-
build_and_test-${{ runner.os }}-pip-
68-
- name: Install DisCoPy from GitHub
57+
- name: Install DisCoPy 0.5 from GitHub
6958
if: github.ref_name != 'release' && github.ref_name != 'beta'
70-
run: pip install git+https://github.com/oxford-quantum-group/discopy
59+
run: pip install git+https://github.com/discopy/discopy@0.5
7160
- name: Install base package
7261
run: pip install .
7362
- name: Check package import works
@@ -130,9 +119,9 @@ jobs:
130119
matrix:
131120
python-version: [ 3.8, 3.9, "3.10" ]
132121
steps:
133-
- uses: actions/checkout@v2
122+
- uses: actions/checkout@v3
134123
- name: Setup Python ${{ matrix.python-version }}
135-
uses: actions/setup-python@v2
124+
uses: actions/setup-python@v4
136125
with:
137126
python-version: ${{ matrix.python-version }}
138127
- name: Install dependencies with type hints

.github/workflows/docs.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ jobs:
1616
name: Build and deploy documentation
1717
runs-on: ubuntu-latest
1818
steps:
19-
- uses: actions/checkout@v2
19+
- uses: actions/checkout@v3
2020
with:
2121
fetch-depth: 0 # fetches tags, required for version info
2222
- name: Set up Python
23-
uses: actions/setup-python@v2
23+
uses: actions/setup-python@v4
2424
with:
2525
python-version: 3.8
2626
- name: Build lambeq

docs/conf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
]
4747

4848
intersphinx_mapping = {
49-
'discopy': ("https://discopy.readthedocs.io/en/main/", None),
49+
'discopy': ("https://discopy.readthedocs.io/en/0.5/", None),
5050
'pennylane': ("https://pennylane.readthedocs.io/en/stable/", None),
5151
}
5252

docs/release_notes.rst

+14
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,20 @@
33
Release notes
44
=============
55

6+
.. _rel-0.2.8:
7+
8+
`0.2.8 <https://github.com/CQCL/lambeq/releases/tag/0.2.8>`_
9+
------------------------------------------------------------
10+
Changed:
11+
12+
- Improved the performance of :py:class:`.NumpyModel` when using Jax JIT-compilation.
13+
- Dependencies: pinned the required version of DisCoPy to 0.5.X.
14+
15+
Fixed:
16+
17+
- Fixed incorrectly scaled validation loss in progress bar during model training.
18+
- Fixed symbol type mismatch in the quantum models when a circuit was previously converted to tket.
19+
620
.. _rel-0.2.7:
721

822
`0.2.7 <https://github.com/CQCL/lambeq/releases/tag/0.2.7>`_

lambeq/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717
'__version_info__',
1818

1919
'ansatz',
20-
'text2diagram',
2120
'core',
2221
'pregroups',
23-
'reader',
2422
'rewrite',
23+
'text2diagram',
2524
'tokeniser',
2625
'training',
2726

lambeq/ansatz/circuit.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@
3434
from discopy.quantum.gates import Bra, H, Ket, Rx, Ry, Rz
3535
from discopy.rigid import Box, Diagram, Ty
3636
import numpy as np
37-
from sympy import symbols
37+
from sympy import Symbol, symbols
3838

39-
from lambeq.ansatz import BaseAnsatz, Symbol
39+
from lambeq.ansatz import BaseAnsatz
4040

4141
computational_basis = Id(qubit)
4242

lambeq/text2diagram/ccg_parser.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from typing import Optional
2323

2424
from discopy import Diagram
25-
from tqdm.autonotebook import tqdm
25+
from tqdm.auto import tqdm
2626

2727
from lambeq.core.globals import VerbosityLevel
2828
from lambeq.core.utils import (SentenceBatchType, SentenceType,

lambeq/training/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from typing import Any, Union
2727

2828
from discopy.tensor import Diagram
29-
from sympy import default_sort_key
29+
from sympy import default_sort_key, Symbol as SymPySymbol
3030

3131
from lambeq.ansatz.base import Symbol
3232
from lambeq.training.checkpoint import Checkpoint
@@ -50,7 +50,7 @@ class Model(ABC):
5050

5151
def __init__(self) -> None:
5252
"""Initialise an instance of :py:class:`Model` base class."""
53-
self.symbols: list[Symbol] = []
53+
self.symbols: list[Union[Symbol, SymPySymbol]] = []
5454
self.weights: Collection = []
5555

5656
def __call__(self, *args: Any, **kwds: Any) -> Any:

lambeq/training/numpy_model.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,19 @@
2828

2929
from collections.abc import Callable, Iterable
3030
import pickle
31-
from typing import Any
31+
from typing import Any, TYPE_CHECKING, Union
3232

3333
from discopy import Tensor
3434
from discopy.tensor import Diagram
3535
import numpy
3636
from numpy.typing import ArrayLike
3737
from sympy import lambdify
3838

39+
40+
if TYPE_CHECKING:
41+
from jax import numpy as jnp
42+
43+
3944
from lambeq.training.quantum_model import QuantumModel
4045

4146

@@ -74,7 +79,7 @@ def _get_lambda(self, diagram: Diagram) -> Callable[[Any], Any]:
7479
if diagram in self.lambdas:
7580
return self.lambdas[diagram]
7681

77-
def diagram_output(*x: ArrayLike) -> ArrayLike:
82+
def diagram_output(x: Iterable[ArrayLike]) -> ArrayLike:
7883
with Tensor.backend('jax'), tn.DefaultBackend('jax'):
7984
sub_circuit = self._fast_subs([diagram], x)[0]
8085
result = tn.contractors.auto(*sub_circuit.to_tn()).tensor
@@ -112,7 +117,9 @@ def _fast_subs(self,
112117
b._phase = b._data
113118
return diagrams
114119

115-
def get_diagram_output(self, diagrams: list[Diagram]) -> numpy.ndarray:
120+
def get_diagram_output(self,
121+
diagrams: list[Diagram]) -> Union[jnp.ndarray,
122+
numpy.ndarray]:
116123
"""Return the exact prediction for each diagram.
117124
118125
Parameters
@@ -142,9 +149,13 @@ def get_diagram_output(self, diagrams: list[Diagram]) -> numpy.ndarray:
142149
'from pre-trained checkpoint.')
143150

144151
if self.use_jit:
152+
from jax import numpy as jnp
153+
145154
lambdified_diagrams = [self._get_lambda(d) for d in diagrams]
146-
return numpy.array([diag_f(*self.weights)
147-
for diag_f in lambdified_diagrams])
155+
res: jnp.ndarray = jnp.array([diag_f(self.weights)
156+
for diag_f in lambdified_diagrams])
157+
158+
return res
148159

149160
diagrams = self._fast_subs(diagrams, self.weights)
150161
with Tensor.backend('numpy'):

lambeq/training/pytorch_model.py

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class PytorchModel(Model, torch.nn.Module):
3535
"""A lambeq model for the classical pipeline using PyTorch."""
3636

3737
weights: torch.nn.ParameterList # type: ignore[assignment]
38+
symbols: list[Symbol] # type: ignore[assignment]
3839

3940
def __init__(self) -> None:
4041
"""Initialise a PytorchModel."""

lambeq/training/quantum_model.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,16 @@
2121
from __future__ import annotations
2222

2323
from abc import abstractmethod
24-
from typing import Any
24+
from typing import Any, TYPE_CHECKING, Union
2525

2626
from discopy.tensor import Diagram, Tensor
2727
import numpy as np
2828

29+
30+
if TYPE_CHECKING:
31+
from jax import numpy as jnp
32+
33+
2934
from lambeq.training.checkpoint import Checkpoint
3035
from lambeq.training.model import Model
3136

@@ -90,7 +95,6 @@ def initialise_weights(self) -> None:
9095
if not self.symbols:
9196
raise ValueError('Symbols not initialised. Instantiate through '
9297
'`from_diagrams()`.')
93-
assert all(w.size == 1 for w in self.symbols)
9498
self.weights = np.random.rand(len(self.symbols))
9599

96100
def _load_checkpoint(self, checkpoint: Checkpoint) -> None:
@@ -124,7 +128,8 @@ def _make_checkpoint(self) -> Checkpoint:
124128
return checkpoint
125129

126130
@abstractmethod
127-
def get_diagram_output(self, diagrams: list[Diagram]) -> np.ndarray:
131+
def get_diagram_output(self, diagrams: list[Diagram]) -> Union[jnp.ndarray,
132+
np.ndarray]:
128133
"""Return the diagram prediction.
129134
130135
Parameters

lambeq/training/trainer.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,7 @@ def writer_helper(*args: Any) -> None:
411411
if val_dataset is not None:
412412
if epoch % evaluation_step == 0:
413413
val_loss = 0.0
414+
seen_so_far = 0
414415
batches_per_validation = ceil(len(val_dataset)
415416
/ val_dataset.batch_size)
416417
with Tensor.backend(self.backend):
@@ -425,6 +426,7 @@ def writer_helper(*args: Any) -> None:
425426
x_val, y_label_val = v_batch
426427
y_hat_val, cur_loss = self.validation_step(v_batch)
427428
val_loss += cur_loss * len(x_val)
429+
seen_so_far += len(x_val)
428430
if self.evaluate_functions is not None:
429431
for metr, func in (
430432
self.evaluate_functions.items()):
@@ -434,7 +436,7 @@ def writer_helper(*args: Any) -> None:
434436
status_bar.set_description(
435437
self._generate_stat_report(
436438
train_loss=train_loss,
437-
val_loss=val_loss))
439+
val_loss=val_loss/seen_so_far))
438440
val_loss /= len(val_dataset)
439441
self.val_costs.append(val_loss)
440442
status_bar.set_description(

setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ packages =
5252
lambeq.tokeniser
5353
lambeq.training
5454
install_requires =
55-
discopy >= 0.4.3
55+
discopy ~= 0.5.0
5656
pytket >= 0.19.2
5757
pyyaml
5858
spacy >= 3.0

tests/test_circuit.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
from discopy.quantum import (Bra, CRz, CRx, CX, X, H, Ket,
55
qubit, Rx, Ry, Rz, sqrt, Controlled)
66
from discopy.quantum.circuit import Circuit, Id
7+
from discopy.quantum.tk import from_tk
8+
from sympy import Symbol as sym
79

810
from lambeq import (AtomicType, IQPAnsatz, Sim14Ansatz, Sim15Ansatz,
911
StronglyEntanglingAnsatz)
10-
from lambeq import Symbol as sym
1112

1213
N = AtomicType.NOUN
1314
S = AtomicType.SENTENCE
@@ -255,3 +256,11 @@ def test_strongly_entangling_ansatz_ranges_error2():
255256
with pytest.raises(ValueError):
256257
ansatz = StronglyEntanglingAnsatz({q: 2}, 3, ranges=[1, 1, 2])
257258
ansatz(box)
259+
260+
def test_discopy_tket_conversion():
261+
word1, word2 = Word('Alice', N), Word('Bob', N.r)
262+
sentence = word1 @ word2 >> Cup(N, N.r)
263+
ansatz = IQPAnsatz({N: 1}, n_layers=1)
264+
circuit = ansatz(sentence)
265+
circuit_converted = from_tk(circuit.to_tk())
266+
assert circuit.free_symbols == circuit_converted.free_symbols

0 commit comments

Comments
 (0)