Skip to content

Commit 7416d3a

Browse files
Thomas Hoffmann, ianyfan, Charles London, y-richie-y
committed
Release version 0.2.4
Co-authored-by: Ian Fan <ian.fan@cambridgequantum.com> Co-authored-by: Charles London <charles.london@cambridgequantum.com> Co-authored-by: Richie Yeung <richie.yeung@cambridgequantum.com>
1 parent 04e4f73 commit 7416d3a

File tree

11 files changed

+107
-98
lines changed

11 files changed

+107
-98
lines changed

.github/workflows/docs.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ jobs:
3434
- name: Move install script
3535
run: mv install.sh docs/_build/html
3636
- name: Deploy documentation
37-
if: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
37+
if: ${{ github.event_name == 'push' && (github.ref_name == 'main' || github.ref_name == 'release') }}
3838
uses: s0/git-publish-subdir-action@develop
3939
env:
4040
REPO: self

README.md

-14
Original file line numberDiff line numberDiff line change
@@ -92,20 +92,6 @@ make html
9292
```
9393
the docs will be under `docs/_build`.
9494

95-
## Known issues
96-
97-
When using lambeq on a Windows machine, the instantiation of the BobcatParser
98-
might trigger an SSL certificate error. We are currently investigating the
99-
issue. In the meantime, you can download the model through this
100-
[link](https://qnlp.cambridgequantum.com/models/bert/latest/model.tar.gz),
101-
extract the archive, and provide the path to the BobcatParser:
102-
103-
```python
104-
from lambeq import BobcatParser
105-
106-
parser = BobcatParser('path/to/model_dir')
107-
```
108-
10995
## License
11096

11197
Distributed under the Apache 2.0 license. See [`LICENSE`](LICENSE) for

docs/release_notes.rst

+11
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,17 @@
33
Release notes
44
=============
55

6+
.. _rel-0.2.4:
7+
8+
`0.2.4 <https://github.com/CQCL/lambeq/releases/tag/0.2.4>`_
9+
------------------------------------------------------------
10+
11+
- Fix a bug that caused the :py:class:`~lambeq.BobcatParser` and the :py:class:`~lambeq.WebParser` to trigger an SSL certificate error on Windows.
12+
13+
- Fix false positives when assigning the conjunction rule in the :py:class:`~lambeq.CCGBankParser`. The rule ``, + X[conj] -> X[conj]`` is a case of removing left punctuation, but was erroneously being assigned the conjunction rule.
14+
15+
- Add support for using ``jax`` as the backend of ``tensornetwork`` when setting ``use_jit=True`` in the :py:class:`~lambeq.NumpyModel`. The interface is not affected by this change, but the performance of the model is significantly improved.
16+
617
.. _rel-0.2.3:
718

819
`0.2.3 <https://github.com/CQCL/lambeq/releases/tag/0.2.3>`_

docs/troubleshooting.rst

+14-4
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,24 @@ encourage you to
1111
SSL error [Windows]
1212
-------------------
1313

14-
When using ``lambeq`` on a Windows machine, the instantiation of the
15-
BobcatParser might trigger an SSL certificate error. We are currently
16-
investigating the issue. In the meantime, you can download the model through
17-
this
14+
When using ``lambeq <= 0.2.3`` on a Windows machine, the instantiation of the
15+
BobcatParser might trigger an SSL certificate error. If you require
16+
``lambeq <= 0.2.3``, you can download the model through this
1817
`link <https://qnlp.cambridgequantum.com/models/bert/latest/model.tar.gz>`_,
1918
extract the archive, and provide the path to the BobcatParser:
2019

2120
.. code-block:: python
2221
2322
from lambeq import BobcatParser
2423
parser = BobcatParser('path/to/model_dir')
24+
25+
Note that using the :py:class:`~lambeq.WebParser` will most likely result in
26+
the same error.
27+
28+
However, this was resolved in release
29+
`0.2.4 <https://github.com/CQCL/lambeq/releases/tag/0.2.4>`_. Please consider
30+
upgrading lambeq:
31+
32+
.. code-block:: bash
33+
34+
pip install --upgrade lambeq

lambeq/text2diagram/bobcat_parser.py

+28-28
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@
1919
import json
2020
import os
2121
from pathlib import Path
22-
import shutil
22+
import requests
2323
import sys
2424
import tarfile
25+
import tempfile
2526
from typing import Any, Iterable, Optional, Union
26-
from urllib.request import urlopen, urlretrieve
2727
import warnings
2828

2929
from discopy.biclosed import Ty
@@ -83,8 +83,7 @@ def model_is_stale(model: str, model_dir: str) -> bool:
8383
return False
8484

8585
try:
86-
with urlopen(url) as f:
87-
remote_version = f.read().strip().decode("utf-8")
86+
remote_version = requests.get(url).text.strip()
8887
except Exception:
8988
return False
9089

@@ -107,41 +106,42 @@ def download_model(
107106
if model_dir is None:
108107
model_dir = get_model_dir(model_name)
109108

110-
class ProgressBar:
111-
bar = None
112-
113-
def update(self, chunk: int, chunk_size: int, size: int) -> None:
114-
if self.bar is None:
115-
self.bar = tqdm(
116-
bar_format='Downloading model: {percentage:3.1f}%|'
117-
'{bar}|{n:.3f}/{total:.3f}GB '
118-
'[{elapsed}<{remaining}]',
119-
total=size/1e9)
120-
warnings.filterwarnings('ignore', category=TqdmWarning)
121-
self.bar.update(chunk_size/1e9)
122-
123-
def close(self):
124-
self.bar.close()
125-
126109
if verbose == VerbosityLevel.TEXT.value:
127110
print('Downloading model...', file=sys.stderr)
128111
if verbose == VerbosityLevel.PROGRESS.value:
129-
progress_bar = ProgressBar()
130-
model_file, headers = urlretrieve(url, reporthook=progress_bar.update)
131-
progress_bar.close()
112+
response = requests.get(url, stream=True)
113+
size = int(response.headers.get('content-length', 0))
114+
block_size = 1024
115+
116+
warnings.filterwarnings('ignore', category=TqdmWarning)
117+
progress_bar = tqdm(
118+
bar_format='Downloading model: {percentage:3.1f}%|'
119+
'{bar}|{n:.3f}/{total:.3f}GB '
120+
'[{elapsed}<{remaining}]',
121+
total=size/1e9)
122+
123+
model_file = tempfile.NamedTemporaryFile()
124+
for data in response.iter_content(block_size):
125+
progress_bar.update(len(data)/1e9)
126+
model_file.write(data)
127+
132128
else:
133-
model_file, headers = urlretrieve(url)
129+
content = requests.get(url).content
130+
model_file = tempfile.NamedTemporaryFile()
131+
model_file.write(content)
134132

135133
# Extract model
134+
model_file.seek(0)
136135
if verbose != VerbosityLevel.SUPPRESS.value:
137136
print('Extracting model...')
138-
with tarfile.open(model_file) as tar:
139-
tar.extractall(model_dir)
137+
tar = tarfile.open(fileobj=model_file)
138+
tar.extractall(model_dir)
139+
model_file.close()
140140

141141
# Download version
142142
ver_url = get_model_url(model_name) + '/' + VERSION_FNAME
143-
ver_file, headers = urlretrieve(ver_url)
144-
shutil.copy(ver_file, model_dir / VERSION_FNAME) # type: ignore
143+
with open(os.path.join(model_dir, VERSION_FNAME), 'wb') as w:
144+
w.write(requests.get(ver_url).content)
145145

146146

147147
class BobcatParseError(Exception):

lambeq/text2diagram/ccgbank_parser.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
from lambeq.text2diagram.ccg_parser import CCGParser
4646
from lambeq.text2diagram.ccg_rule import CCGRule
4747
from lambeq.text2diagram.ccg_tree import CCGTree
48-
from lambeq.text2diagram.ccg_types import CONJ_TAG, CCGAtomicType, str2biclosed
48+
from lambeq.text2diagram.ccg_types import CCGAtomicType, str2biclosed
4949

5050

5151
class CCGBankParseError(Exception):
@@ -408,12 +408,9 @@ def _build_ccgtree(sentence: str, start: int) -> tuple[CCGTree, int]:
408408
child, pos = CCGBankParser._build_ccgtree(sentence, pos)
409409
children.append(child)
410410

411-
if tree_match['ccg_str'].endswith(CONJ_TAG):
412-
rule = CCGRule.CONJUNCTION
413-
else:
414-
rule = CCGRule.infer_rule(
415-
Ty().tensor(*(child.biclosed_type for child in children)),
416-
biclosed_type)
411+
rule = CCGRule.infer_rule(Ty().tensor(*(child.biclosed_type
412+
for child in children)),
413+
biclosed_type)
417414
ccg_tree = CCGTree(rule=rule,
418415
biclosed_type=biclosed_type,
419416
children=children)

lambeq/text2diagram/web_parser.py

+12-19
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,9 @@
1515

1616
__all__ = ['WebParser', 'WebParseError']
1717

18-
import json
18+
import requests
1919
import sys
2020
from typing import Optional
21-
from urllib.error import HTTPError
22-
from urllib.parse import urlencode
23-
from urllib.request import urlopen
2421

2522
from tqdm.auto import tqdm
2623

@@ -34,13 +31,14 @@
3431

3532

3633
class WebParseError(OSError):
37-
def __init__(self, sentence: str, error_code: int) -> None:
34+
def __init__(self, sentence: str) -> None:
3835
self.sentence = sentence
39-
self.error_code = error_code
4036

4137
def __str__(self) -> str:
42-
return (f'Online parsing of sentence {repr(self.sentence)} failed, '
43-
f'Web status code: {self.error_code}.')
38+
return (f'Web parser could not parse {repr(self.sentence)}.'
39+
'Check that you are using the correct URL. '
40+
'If the URL is correct, this means the parser could not parse '
41+
'your sentence.')
4442

4543

4644
class WebParser(CCGParser):
@@ -136,25 +134,20 @@ def sentences2trees(
136134
trees: list[Optional[CCGTree]] = []
137135
if verbose == VerbosityLevel.TEXT.value:
138136
print('Parsing sentences.', file=sys.stderr)
139-
for sent in tqdm(
137+
for sentence in tqdm(
140138
sentences,
141139
desc='Parsing sentences',
142140
leave=False,
143141
disable=verbose != VerbosityLevel.PROGRESS.value):
144-
params = urlencode({'sentence': sent})
145-
url = f'{self.service_url}?{params}'
142+
params = {'sentence': sentence}
146143

147144
try:
148-
with urlopen(url) as f:
149-
data = json.load(f)
150-
except HTTPError as e:
151-
if suppress_exceptions:
152-
tree = None
153-
else:
154-
raise WebParseError(str(sentence), e.code)
155-
except Exception as e:
145+
data = requests.get(self.service_url, params=params).json()
146+
except requests.RequestException as e:
156147
if suppress_exceptions:
157148
tree = None
149+
elif type(e) == requests.JSONDecodeError:
150+
raise WebParseError(str(sentence))
158151
else:
159152
raise e
160153
else:

lambeq/training/numpy_model.py

+31-23
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from discopy.tensor import Diagram
3636
from sympy import default_sort_key, lambdify
3737

38+
from lambeq.training.model import SizedIterable
3839
from lambeq.training.quantum_model import QuantumModel
3940

4041

@@ -100,13 +101,40 @@ def _get_lambda(self, diagram: Diagram) -> Callable[[Any], Any]:
100101
return self.lambdas[diagram]
101102

102103
def diagram_output(*x):
103-
with Tensor.backend('jax'):
104-
result = diagram.lambdify(*self.symbols)(*x).eval().array
104+
with Tensor.backend('jax'), tn.DefaultBackend('jax'):
105+
sub_circuit = self._fast_subs([diagram], x)[0]
106+
result = tn.contractors.auto(*sub_circuit.to_tn()).tensor
105107
return self._normalise_vector(result)
106108

107109
self.lambdas[diagram] = jit(diagram_output)
108110
return self.lambdas[diagram]
109111

112+
def _fast_subs(self,
113+
diagrams: list[Diagram],
114+
weights: SizedIterable) -> list[Diagram]:
115+
"""Substitute weights into a list of parameterised circuit."""
116+
parameters = {k: v for k, v in zip(self.symbols, weights)}
117+
diagrams = pickle.loads(pickle.dumps(diagrams)) # does fast deepcopy
118+
for diagram in diagrams:
119+
for b in diagram._boxes:
120+
if b.free_symbols:
121+
while hasattr(b, 'controlled'):
122+
b._free_symbols = set()
123+
b = b.controlled
124+
syms, values = [], []
125+
for sym in b._free_symbols:
126+
syms.append(sym)
127+
try:
128+
values.append(parameters[sym])
129+
except KeyError:
130+
raise KeyError(f'Unknown symbol {sym!r}.')
131+
b._data = lambdify(syms, b._data)(*values)
132+
b.drawing_name = b.name
133+
b._free_symbols = set()
134+
if hasattr(b, '_phase'):
135+
b._phase = b._data
136+
return diagrams
137+
110138
def get_diagram_output(self, diagrams: list[Diagram]) -> numpy.ndarray:
111139
"""Return the exact prediction for each diagram.
112140
@@ -139,27 +167,7 @@ def get_diagram_output(self, diagrams: list[Diagram]) -> numpy.ndarray:
139167
return numpy.array([diag_f(*self.weights)
140168
for diag_f in lambdified_diagrams])
141169

142-
parameters = {k: v for k, v in zip(self.symbols, self.weights)}
143-
diagrams = pickle.loads(pickle.dumps(diagrams)) # does fast deepcopy
144-
for diagram in diagrams:
145-
for b in diagram._boxes:
146-
if b.free_symbols:
147-
while hasattr(b, 'controlled'):
148-
b._free_symbols = set()
149-
b = b.controlled
150-
syms, values = [], []
151-
for sym in b._free_symbols:
152-
syms.append(sym)
153-
try:
154-
values.append(parameters[sym])
155-
except KeyError:
156-
raise KeyError(f'Unknown symbol {sym!r}.')
157-
b._data = lambdify(syms, b._data)(*values)
158-
b.drawing_name = b.name
159-
b._free_symbols = set()
160-
if hasattr(b, '_phase'):
161-
b._phase = b._data
162-
170+
diagrams = self._fast_subs(diagrams, self.weights)
163171
with Tensor.backend('numpy'):
164172
return numpy.array([
165173
self._normalise_vector(tn.contractors.auto(*d.to_tn()).tensor)

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ module = [
1919
"depccg.*",
2020
"discopy.*",
2121
"jax.*",
22+
"requests.*",
2223
"sympy.*",
2324
"tensornetwork.*",
2425
"tqdm.*",

tests/text2diagram/test_reader.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
from requests.exceptions import MissingSchema
23

34
from discopy import Word
45
from discopy.rigid import Box, Diagram, Id, Spider
@@ -8,6 +9,7 @@
89
from lambeq import (AtomicType, BobcatParser, IQPAnsatz, TreeReader,
910
TreeReaderMode, VerbosityLevel, WebParser, cups_reader,
1011
spiders_reader, stairs_reader)
12+
from lambeq.text2diagram.web_parser import WebParseError
1113

1214

1315
@pytest.fixture
@@ -108,7 +110,7 @@ def test_suppress_exceptions(sentence):
108110
assert bad_reader.sentence2diagram(sentence) is None
109111

110112
bad_reader = TreeReader(bad_parser, suppress_exceptions=False)
111-
with pytest.raises(ValueError):
113+
with pytest.raises(MissingSchema):
112114
bad_reader.sentence2diagram(sentence)
113115

114116

tests/text2diagram/test_web_parser.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from io import StringIO
2+
from shutil import ExecError
23
import pytest
34
from unittest.mock import patch
45

@@ -56,7 +57,7 @@ def test_bad_url():
5657

5758
assert bad_parser.sentence2diagram(
5859
"Need a proper url", suppress_exceptions=True) is None
59-
with pytest.raises(WebParseError):
60+
with pytest.raises(Exception):
6061
bad_parser.sentence2diagram("Need a proper url")
6162

6263

0 commit comments

Comments
 (0)