# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
# __all__ = ['DisCoCircReader']
from collections import Counter
from collections.abc import Callable, Iterable
from lambeq import AtomicType
from lambeq.backend.grammar import Box, Diagram, Frame, Id, Spider, Ty
from lambeq.core.utils import SentenceBatchType, SentenceType
from lambeq.experimental.discocirc import (CoreferenceResolver,
MaverickCoreferenceResolver,
TreeRewriter,
TreeRewriteRule)
from lambeq.text2diagram import (BobcatParser,
ModelBasedReader,
OncillaParser,
Reader)
from lambeq.text2diagram.pregroup_tree import PregroupTreeNode
# Shorthand for the atomic noun type; used below for every noun wire
# and noun state created by the reader.
NOUN = AtomicType.NOUN
[docs]
class DisCoCircReader(Reader):
    """A reader that converts text to a DisCoCirc diagram.

    The reader is configured with a sentence parser and a coreference
    resolver. Only `text2circuit` is implemented; `sentence2diagram`
    and `sentences2diagrams` raise `NotImplementedError`.
    """
[docs]
def __init__(
    self,
    parser:
        ModelBasedReader
        | Callable[[], ModelBasedReader] = BobcatParser,
    coref_resolver:
        CoreferenceResolver
        | Callable[[], CoreferenceResolver] = MaverickCoreferenceResolver
) -> None:
    """Initialise the reader with a parser and a coreference resolver.

    Parameters
    ----------
    parser : ModelBasedReader or callable, default: BobcatParser
        Either a parser instance, or a zero-argument callable
        (e.g. the parser class itself) that returns one. The
        resulting object must be a `BobcatParser` or an
        `OncillaParser`.
    coref_resolver : CoreferenceResolver or callable, \
default: MaverickCoreferenceResolver
        Either a coreference resolver instance, or a zero-argument
        callable that returns one.

    Raises
    ------
    ValueError
        If the resolved parser is neither a `BobcatParser` nor an
        `OncillaParser`, or if the resolved coreference resolver is
        not a `CoreferenceResolver`.

    """
    # Classes and factories are callable: instantiate them lazily so
    # that the defaults are only built when actually needed.
    if callable(parser):
        parser = parser()
    if not isinstance(parser, (BobcatParser, OncillaParser)):
        raise ValueError(f'{parser} should either be a BobcatParser, '
                         'an OncillaParser, or a function '
                         'that returns either.')

    if callable(coref_resolver):
        coref_resolver = coref_resolver()
    if not isinstance(coref_resolver, CoreferenceResolver):
        # Note the trailing space after "a": without it the message
        # rendered as "should be aCoreferenceResolver".
        raise ValueError(f'{coref_resolver} should be a '
                         'CoreferenceResolver or a function that '
                         'returns a CoreferenceResolver.')

    self.parser = parser
    self.coref_resolver = coref_resolver
[docs]
def sentence2diagram(self,
                     sentence: SentenceType,
                     tokenised: bool = False) -> Diagram | None:
    """Not implemented for this reader.

    Raises
    ------
    NotImplementedError
        Always.

    """
    message = ('DisCoCircReader only supports '
               'text2circuit presently.')
    raise NotImplementedError(message)
[docs]
def sentences2diagrams(
    self,
    sentences: SentenceBatchType,
    tokenised: bool = False
) -> list[Diagram | None]:
    """Not implemented for this reader.

    Raises
    ------
    NotImplementedError
        Always.

    """
    message = ('DisCoCircReader only supports '
               'text2circuit presently.')
    raise NotImplementedError(message)
def _sentence2tree(self, sentence: list[str], break_cycles: bool = True):
    """Parse a tokenised sentence and return its pregroup tree.

    An `OncillaParser` produces the tree directly; any other parser
    produces a diagram first, which is then converted to a tree.
    If the root type is neither ``s`` nor ``n``, the root is retyped
    to ``s`` and an empty-word noun node is appended to its children.
    """
    parser = self.parser
    if isinstance(parser, OncillaParser):
        tree = parser._sentence2pregrouptree(
            sentence, tokenised=True, break_cycles=break_cycles,
        )
    else:
        diagram = parser.sentence2diagram(sentence, tokenised=True)
        tree = diagram.to_pregroup_tree(break_cycles=break_cycles)

    if tree.typ in (Ty('s'), Ty('n')):
        return tree

    # The parse left free wires: correct it by adding the missing noun.
    tree.typ = Ty('s')
    tree.children.append(
        PregroupTreeNode('', len(sentence)-1, Ty('n'), [], [], []))
    return tree
def _tree2circuital(self,
node: PregroupTreeNode,
sandwich: bool,
pruned_ids: Iterable[int] = (),
foliated_frame_labels: bool = True):
if sandwich:
sdiag, nouns, nids, _ = self._tree2sandwiches_rec(
node,
pruned_ids=pruned_ids,
foliated_frame_labels=foliated_frame_labels)
else:
sdiag, nouns, nids = self._tree2frames_rec(node, pruned_ids)
return sdiag, nouns, nids
def _tree2frames_rec(self, node, pruned_ids):
    """Convert a tree made of `PregroupTreeNode`s into a diagram
    consisting of nested frames. Implements dragging out of nouns in
    a single post-order traversal of the tree.

    Returns a triple ``(diagram, nouns, noun_inds)`` where `nouns` are
    the noun state boxes collected from the subtree and `noun_inds`
    are the corresponding node indices.
    """
    if node.typ == NOUN and not node.children:
        # Leaf noun: either dropped entirely (pruned) or turned into a
        # bare wire with its state pulled out.
        if node.ind in pruned_ids:
            return Id(), [], []
        return Id(NOUN), [Box(node.word, Ty(), NOUN)], [node.ind]

    child_results = [self._tree2frames_rec(child, pruned_ids)
                     for child in node.children]

    subdiags = []
    nouns = []
    noun_inds = []
    for child_diag, child_nouns, child_inds in child_results:
        # Children that reduced to identities do not fill a frame slot.
        if child_diag.layers:
            subdiags.append(child_diag)
        nouns.extend(child_nouns)
        noun_inds.extend(child_inds)

    # Keep only the adjoint (z != 0) atoms of this node's type.
    adjoints = [atom for atom in node.typ if atom.z != 0]
    dom = Ty().tensor(Ty(), *adjoints)

    if not nouns and not dom:
        return Id(), [], []

    boundary = NOUN ** len(nouns) @ dom.l
    if subdiags:
        frame = Frame(node.word, boundary, boundary, 0, subdiags)
        return frame, nouns, noun_inds
    return Box(node.word, boundary, boundary), nouns, noun_inds
def _tree2sandwiches_rec(self,
                         node,
                         previous_noun=None,
                         pruned_ids: Iterable[int] = (),
                         foliated_frame_labels=True):
    """Convert a tree made of `PregroupTreeNode`s into a diagram
    consisting of only boxes. Implements dragging out of nouns in
    a single post-order traversal of the tree.

    Parameters
    ----------
    node : PregroupTreeNode
        Root of the (sub)tree to convert.
    previous_noun : optional
        Noun index handed down by the caller. It is attached as
        wire 0 of this node when the node's type contains ``NOUN.l``
        or ``NOUN.r`` (see the condition below).
    pruned_ids : iterable of int, default: ()
        Indices of leaf nouns that must not appear in the result.
    foliated_frame_labels : bool, default: True
        Whether the boxes of one sandwich get ``top`` / numbered /
        ``bottom`` label suffixes.

    Returns
    -------
    tuple
        ``(diagram, nouns, noun_inds, noun_cursor)``: the diagram for
        the subtree, the noun state boxes it needs, the indices of
        those nouns (parallel to `nouns`), and the noun index to hand
        to whatever is processed next.

    """
    # A pruned leaf noun contributes neither a wire nor a state.
    if node.typ == NOUN and not node.children and node.ind in pruned_ids:
        return Id(), [], [], previous_noun

    # Any other leaf noun is a bare wire; its state box is returned
    # separately and it becomes the most recent noun.
    if node.typ == NOUN and not node.children:
        noun_box = Box(node.word, Ty(), NOUN)
        return Id(NOUN), [noun_box], [node.ind], node.ind

    subdiags = []
    nouns = []
    noun_inds = []
    # Maps a noun index to the wire position it occupies in this node.
    noun2wire = {}
    noun_cursor = previous_noun
    # NOTE(review): this initial value is never used; `bigdiag` is
    # reassigned before it is read.
    bigdiag = Id()

    # If this node's type has a noun adjoint, put `previous_noun` on
    # wire 0 with an empty-label state, unless there is no previous
    # noun while some ids are being pruned.
    if NOUN.l in node.typ or NOUN.r in node.typ:
        if previous_noun is not None or len(pruned_ids) == 0:
            noun_inds = [previous_noun] + noun_inds
            nouns = [Box('', Ty(), NOUN)] + nouns
            noun2wire[previous_noun] = 0

    for child in node.children:
        (c_subdiag,
         c_nouns,
         c_noun_inds,
         candidate_noun) = self._tree2sandwiches_rec(child,
                                                     noun_cursor,
                                                     pruned_ids,
                                                     foliated_frame_labels)
        # Only noun-typed children update the running "latest noun".
        noun_cursor = (candidate_noun if child.typ == NOUN
                       else noun_cursor)

        # NOTE(review): `pre_ancillae_count` is incremented below but
        # never read in this method.
        pre_ancillae_count = 0
        # Start from all wires allocated so far; the ones this child
        # touches are removed in the loop, leaving the untouched wires.
        ancilla_nouns = set(noun2wire.values())
        wire_ids = []

        for j, nid in enumerate(c_noun_inds):
            # First time this noun is seen: allocate the next wire.
            if noun2wire.get(nid) is None:
                noun2wire[nid] = len(noun2wire)
                pre_ancillae_count += 1
                nouns.append(c_nouns[j])
                noun_inds.append(c_noun_inds[j])
            ancilla_nouns = ancilla_nouns - {noun2wire[nid]}
            wire_ids.append(noun2wire[nid])

        # Untouched wires first, then the child's wires in child order.
        wire_ids = list(sorted(ancilla_nouns))+wire_ids
        # Position of each wire id in sorted order (for repeated ids,
        # the first of their positions).
        wire_ids_lookup = {wid: len(wire_ids) - i - 1
                           for i, wid in enumerate(reversed(
                               sorted(wire_ids)))}
        perm_list = []
        for wid in wire_ids:
            perm_list.append(wire_ids_lookup[wid])
            # Repeated ids take consecutive positions.
            wire_ids_lookup[wid] += 1

        # Conjugate the child diagram (padded with identities on the
        # untouched wires) by the permutation.
        permute = Diagram.permutation(NOUN ** len(wire_ids), perm_list)
        subdiags.append(permute
                        >> Id(NOUN ** len(ancilla_nouns)) @ c_subdiag
                        >> permute.dagger())

    # Children that reduced to identities need no sandwich layer.
    subdiags = [subdiag for subdiag in subdiags if subdiag.layers]

    if not noun2wire:
        # Floating box, kill
        return Id(), [], [], noun_cursor

    # Opening box of the sandwich, acting on all noun wires.
    bigdiag = Box(node.word + (
                    '$_{top}$' if (foliated_frame_labels
                                   and subdiags) else ''),
                  NOUN ** len(noun2wire),
                  NOUN ** len(noun2wire))

    # Every subdiagram except the last is followed by a numbered box.
    # Subdiagrams are padded on the right with identity wires up to
    # the full width.
    for i, subdiag in enumerate(subdiags[:-1]):
        if subdiag.layers:
            bigdiag >>= subdiag @ NOUN ** (len(noun2wire)
                                           - len(subdiag.cod))
            bigdiag >>= Box(node.word
                            + (f'$_{{{i+1}}}$'
                               if foliated_frame_labels else ''),
                            bigdiag.cod,
                            bigdiag.cod)

    # The last subdiagram is followed by the closing box.
    if subdiags:
        subdiag = subdiags[-1]
        bigdiag >>= subdiag @ NOUN ** (len(noun2wire) - len(subdiag.cod))
        bigdiag >>= Box(node.word
                        + ('$_{bottom}$'
                           if foliated_frame_labels else ''),
                        bigdiag.cod,
                        bigdiag.cod)

    return bigdiag, nouns, noun_inds, noun_cursor
def _get_index(self, s, pnoun):
for j, w in enumerate(s):
if w == pnoun:
return j
return -1
def _prune_indices(self,
sentences,
corefs,
pruned_nouns: Iterable[str] = ()):
initial_ids = []
pruned_ids = [[] for _ in sentences]
# Find first coref of each pruned noun
for pnoun in pruned_nouns:
for i, s in enumerate(sentences):
j = self._get_index(s, pnoun)
if j != -1:
initial_ids.append((i, j))
break
for i, j in initial_ids:
pruned_ids[i].append(j)
for p in corefs:
if i >= len(p):
break
elif j in p[i]:
for h, hword in enumerate(p):
if len(hword) > 0 and h != i:
pruned_ids[h].append(hword[0])
break
return pruned_ids
[docs]
def text2circuit(self,
                 text: str,
                 sandwich: bool = False,
                 break_cycles: bool = True,
                 pruned_nouns: Iterable[str] = (),
                 min_noun_freq: int = 1,
                 rewrite_rules: (
                     Iterable[TreeRewriteRule | str] | None
                 ) = ('determiner', 'auxiliary'),
                 foliated_frame_labels: bool = True
                 ) -> Diagram:
    """Return the DisCoCirc diagram for a given text.

    Parameters
    ----------
    text : str
        A single string that contains one or multiple sentences.
    sandwich : bool, default: False
        If False, returns diagrams using Frames for higher-order
        boxes, else uses sandwiches, including one box between each
        subdiagram of a higher-order box.
    break_cycles : bool, default: True
        Whether to break any cycles present in the pregroup tree.
    pruned_nouns : iterable of strings, default: ()
        If any of the nouns in this list are present in the diagram,
        the corresponding state and wire are removed from the
        diagram.
    min_noun_freq : int, default: 1
        Minimum number of times a noun needs to be referenced to
        appear in the circuit.
    rewrite_rules : iterable of `TreeRewriteRule` or str, or None, \
default: ('determiner', 'auxiliary')
        Rewrite rules to apply to the pregroup tree before
        conversion to a circuit. The value is passed unchanged to
        `TreeRewriter`.
    foliated_frame_labels : bool, default: True
        When sandwich is True, setting to True labels frames with
        numbered suffixes. False makes all sandwich layers have the
        same labels.

    Returns
    -------
    Diagram
        A DisCoCirc diagram for the given text.

    """
    sentences, corefs = self.coref_resolver.tokenise_and_coref(text)
    corefd = self.coref_resolver.dict_from_corefs(corefs)

    # Nouns referenced fewer than `min_noun_freq` times are pruned
    # just like the explicitly requested ones.
    noun_counts = Counter(corefd.values())
    freq_pruned_ids = [nid for nid, count in noun_counts.items()
                       if count < min_noun_freq]
    pruned_nouns = set(pruned_nouns).union(
        {sentences[i][j] for (i, j) in freq_pruned_ids})
    pruned_ids = self._prune_indices(sentences, corefs, pruned_nouns)

    rewriter = TreeRewriter(rewrite_rules)

    bigdiag = Id()
    # Maps the canonical id of a noun to its wire in the final circuit.
    noun2wire = {}
    noun_boxes = []

    for i, sentence in enumerate(sentences):
        tree = self._sentence2tree(sentence, break_cycles)

        # Map the token indices used by the tree back to positions in
        # the tokenised sentence (computed before rewriting the tree).
        tree_toks = tree.get_words()
        tree_toks_indxs = tree.get_word_indices()
        reidxr = self._calculate_reindices(sentence, tree_toks,
                                           tree_toks_indxs)
        # Nodes without an index keep having none.
        reidxr[None] = None

        tree = rewriter(tree)
        tree = self._reindex_nodes(tree, reidxr)

        (sdiag,
         nouns,
         nids) = self._tree2circuital(tree,
                                      sandwich,
                                      pruned_ids[i],
                                      foliated_frame_labels)

        # Qualify noun ids with the sentence number, then replace each
        # by the id its coreference entry points to (if any).
        qual_nids = [(i, nid) for nid in nids]
        unique_nids = [corefd.get(nid, nid) for nid in qual_nids]

        wire_ids = []
        # Wires that exist already but are not used by this sentence.
        ancilla_nouns = set(noun2wire.values())
        # Number of wires introduced by this sentence.
        pre_ancillae_count = 0
        for j, nid in enumerate(unique_nids):
            if noun2wire.get(nid) is None:
                noun2wire[nid] = len(noun2wire)
                pre_ancillae_count += 1
                noun_boxes.append(nouns[j])
            ancilla_nouns = ancilla_nouns - {noun2wire[nid]}
            wire_ids.append(noun2wire[nid])

        # Untouched wires first, then the sentence's wires in order.
        wire_ids = sorted(ancilla_nouns) + wire_ids
        # Position of each wire id in sorted order; repeated ids take
        # consecutive positions as they are consumed below.
        wire_ids_lookup = {wid: len(wire_ids) - pos - 1
                           for pos, wid in enumerate(reversed(
                               sorted(wire_ids)))}
        perm_list = []
        for wid in wire_ids:
            perm_list.append(wire_ids_lookup[wid])
            wire_ids_lookup[wid] += 1

        # A wire used several times by the sentence is copied with a
        # spider. `default=-1` keeps this well-defined when there are
        # no wires at all (a noun-free sentence before any noun has
        # been seen), where a bare `max()` would raise ValueError.
        wire_counter = Counter(wire_ids)
        n_wires = max(wire_ids, default=-1) + 1
        spidering = Id().tensor(
            *[Id(NOUN) if wire_counter[wid] == 1
              else Spider(NOUN,
                          1,
                          wire_counter[wid])
              for wid in range(n_wires)])

        permute = Diagram.permutation(NOUN ** len(wire_ids), perm_list)
        matcher = spidering >> permute

        bigdiag = (bigdiag @ (NOUN ** pre_ancillae_count)
                   >> matcher
                   >> (NOUN ** len(ancilla_nouns)) @ sdiag
                   >> matcher.dagger())

    return Id().tensor(*noun_boxes) >> bigdiag
def _calculate_reindices(self,
orig_toks,
parsed_toks,
parsed_toks_indxs):
reindexer = {}
j = 0
for i, otok in zip(parsed_toks_indxs, parsed_toks):
while j < len(orig_toks) and orig_toks[j] != otok:
j += 1
reindexer[i] = j
j += 1
return reindexer
def _reindex_nodes(self, node, reindexer):
node.ind = reindexer[node.ind]
node.children = [self._reindex_nodes(child, reindexer)
for child in node.children]
return node