
# Copyright 2021-2024 Cambridge Quantum Computing Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

__all__ = ['DisCoCircReader']

from collections import Counter
from collections.abc import Callable, Iterable

from lambeq import AtomicType
from lambeq.backend.grammar import Box, Diagram, Frame, Id, Spider, Ty
from lambeq.core.utils import SentenceBatchType, SentenceType
from lambeq.experimental.discocirc import (CoreferenceResolver,
                                           MaverickCoreferenceResolver,
                                           TreeRewriter,
                                           TreeRewriteRule)
from lambeq.text2diagram import (BobcatParser,
                                 ModelBasedReader,
                                 OncillaParser,
                                 Reader)
from lambeq.text2diagram.pregroup_tree import PregroupTreeNode


NOUN = AtomicType.NOUN


class DisCoCircReader(Reader):
    """A reader that converts text to a DisCoCirc diagram."""

    def __init__(
        self,
        parser: ModelBasedReader
            | Callable[[], ModelBasedReader] = BobcatParser,
        coref_resolver: CoreferenceResolver
            | Callable[[], CoreferenceResolver] = MaverickCoreferenceResolver
    ) -> None:
        if callable(parser):
            parser = parser()
        if not isinstance(parser, (BobcatParser, OncillaParser)):
            raise ValueError(f'{parser} should either be a BobcatParser, '
                             'an OncillaParser, or a function '
                             'that returns either.')

        if callable(coref_resolver):
            coref_resolver = coref_resolver()
        if not isinstance(coref_resolver, CoreferenceResolver):
            raise ValueError(f'{coref_resolver} should be a '
                             'CoreferenceResolver or a function that '
                             'returns a CoreferenceResolver.')

        self.parser = parser
        self.coref_resolver = coref_resolver

    def sentence2diagram(self,
                         sentence: SentenceType,
                         tokenised: bool = False) -> Diagram | None:
        raise NotImplementedError('DisCoCircReader only supports '
                                  'text2circuit presently.')

    def sentences2diagrams(
        self,
        sentences: SentenceBatchType,
        tokenised: bool = False
    ) -> list[Diagram | None]:
        raise NotImplementedError('DisCoCircReader only supports '
                                  'text2circuit presently.')

    def _sentence2tree(self,
                       sentence: list[str],
                       break_cycles: bool = True):
        """Get the pregroup tree from a pregroup diagram generated
        by the parser."""
        if isinstance(self.parser, OncillaParser):
            root = self.parser._sentence2pregrouptree(
                sentence,
                tokenised=True,
                break_cycles=break_cycles,
            )
        else:
            d = self.parser.sentence2diagram(sentence, tokenised=True)
            root = d.to_pregroup_tree(break_cycles=break_cycles)

        # If the diagram has free wires, correct this by adding the
        # missing noun
        if root.typ not in (Ty('s'), Ty('n')):
            root.typ = Ty('s')
            root.children.append(PregroupTreeNode(
                '', len(sentence) - 1, Ty('n'), [], [], []))

        return root

    def _tree2circuital(self,
                        node: PregroupTreeNode,
                        sandwich: bool,
                        pruned_ids: Iterable[int] = (),
                        foliated_frame_labels: bool = True):
        if sandwich:
            sdiag, nouns, nids, _ = self._tree2sandwiches_rec(
                node,
                pruned_ids=pruned_ids,
                foliated_frame_labels=foliated_frame_labels)
        else:
            sdiag, nouns, nids = self._tree2frames_rec(node, pruned_ids)

        return sdiag, nouns, nids

    def _tree2frames_rec(self, node, pruned_ids):
        """Convert a tree made of `PregroupTreeNode`s into a diagram
        consisting of nested frames.

        Implements dragging out of nouns in a single post-order
        traversal of the tree.
        """
        # Pruned leaf noun: contributes nothing to the diagram
        if node.typ == NOUN and not node.children and node.ind in pruned_ids:
            return Id(), [], []

        # Leaf noun: a state that is dragged out to the top of the diagram
        if node.typ == NOUN and not node.children:
            return Id(NOUN), [Box(node.word, Ty(), NOUN)], [node.ind]

        subdiags_n_nouns = [self._tree2frames_rec(child, pruned_ids)
                            for child in node.children]
        subdiags = [d for d, _, _ in subdiags_n_nouns if d.layers]
        nouns = [n for _, ns, _ in subdiags_n_nouns for n in ns]
        noun_inds = [nid for _, _, nids in subdiags_n_nouns for nid in nids]

        dom = Ty().tensor(Ty(), *list(filter(lambda t: t.z != 0, node.typ)))

        if not nouns and not dom:
            return Id(), [], []

        if not subdiags or not any(d.layers for d in subdiags):
            return Box(node.word,
                       NOUN ** len(nouns) @ dom.l,
                       NOUN ** len(nouns) @ dom.l), nouns, noun_inds
        else:
            f = Frame(node.word,
                      NOUN ** len(nouns) @ dom.l,
                      NOUN ** len(nouns) @ dom.l,
                      0,
                      subdiags)
            return f, nouns, noun_inds

    def _tree2sandwiches_rec(self,
                             node,
                             previous_noun=None,
                             pruned_ids: Iterable[int] = (),
                             foliated_frame_labels=True):
        """Convert a tree made of `PregroupTreeNode`s into a diagram
        consisting of only boxes.

        Implements dragging out of nouns in a single post-order
        traversal of the tree.
        """
        # Pruned leaf noun: contributes nothing to the diagram
        if node.typ == NOUN and not node.children and node.ind in pruned_ids:
            return Id(), [], [], previous_noun

        # Leaf noun: a state that is dragged out to the top of the diagram
        if node.typ == NOUN and not node.children:
            noun_box = Box(node.word, Ty(), NOUN)
            return Id(NOUN), [noun_box], [node.ind], node.ind

        subdiags = []
        nouns = []
        noun_inds = []
        noun2wire = {}
        noun_cursor = previous_noun
        bigdiag = Id()

        if NOUN.l in node.typ or NOUN.r in node.typ:
            if previous_noun is not None or len(pruned_ids) == 0:
                noun_inds = [previous_noun] + noun_inds
                nouns = [Box('', Ty(), NOUN)] + nouns
                noun2wire[previous_noun] = 0

        for child in node.children:
            (c_subdiag,
             c_nouns,
             c_noun_inds,
             candidate_noun) = self._tree2sandwiches_rec(
                child, noun_cursor, pruned_ids, foliated_frame_labels)
            noun_cursor = (candidate_noun if child.typ == NOUN
                           else noun_cursor)

            # Assign a wire to each noun the first time it appears
            pre_ancillae_count = 0
            ancilla_nouns = set(noun2wire.values())
            wire_ids = []
            for j, nid in enumerate(c_noun_inds):
                if noun2wire.get(nid) is None:
                    noun2wire[nid] = len(noun2wire)
                    pre_ancillae_count += 1
                    nouns.append(c_nouns[j])
                    noun_inds.append(c_noun_inds[j])
                ancilla_nouns = ancilla_nouns - {noun2wire[nid]}
                wire_ids.append(noun2wire[nid])

            wire_ids = list(sorted(ancilla_nouns)) + wire_ids
            wire_ids_lookup = {wid: len(wire_ids) - i - 1
                               for i, wid in enumerate(reversed(
                                   sorted(wire_ids)))}
            perm_list = []
            for wid in wire_ids:
                perm_list.append(wire_ids_lookup[wid])
                wire_ids_lookup[wid] += 1

            permute = Diagram.permutation(NOUN ** len(wire_ids), perm_list)
            subdiags.append(permute
                            >> Id(NOUN ** len(ancilla_nouns)) @ c_subdiag
                            >> permute.dagger())

        subdiags = [subdiag for subdiag in subdiags if subdiag.layers]

        if not noun2wire:
            # Floating box, kill
            return Id(), [], [], noun_cursor

        bigdiag = Box(node.word + ('$_{top}$'
                                   if (foliated_frame_labels and subdiags)
                                   else ''),
                      NOUN ** len(noun2wire),
                      NOUN ** len(noun2wire))

        for i, subdiag in enumerate(subdiags[:-1]):
            if subdiag.layers:
                bigdiag >>= subdiag @ NOUN ** (len(noun2wire)
                                               - len(subdiag.cod))
                bigdiag >>= Box(node.word + (f'$_{{{i+1}}}$'
                                             if foliated_frame_labels
                                             else ''),
                                bigdiag.cod, bigdiag.cod)

        if subdiags:
            subdiag = subdiags[-1]
            bigdiag >>= subdiag @ NOUN ** (len(noun2wire) - len(subdiag.cod))
            bigdiag >>= Box(node.word + ('$_{bottom}$'
                                         if foliated_frame_labels else ''),
                            bigdiag.cod, bigdiag.cod)

        return bigdiag, nouns, noun_inds, noun_cursor

    def _get_index(self, s, pnoun):
        for j, w in enumerate(s):
            if w == pnoun:
                return j
        return -1

    def _prune_indices(self,
                       sentences,
                       corefs,
                       pruned_nouns: Iterable[str] = ()):
        initial_ids = []
        pruned_ids = [[] for _ in sentences]

        # Find first coref of each pruned noun
        for pnoun in pruned_nouns:
            for i, s in enumerate(sentences):
                j = self._get_index(s, pnoun)
                if j != -1:
                    initial_ids.append((i, j))
                    break

        for i, j in initial_ids:
            pruned_ids[i].append(j)
            for p in corefs:
                if i >= len(p):
                    break
                elif j in p[i]:
                    for h, hword in enumerate(p):
                        if len(hword) > 0 and h != i:
                            pruned_ids[h].append(hword[0])
                    break

        return pruned_ids

    def text2circuit(self,
                     text: str,
                     sandwich: bool = False,
                     break_cycles: bool = True,
                     pruned_nouns: Iterable[str] = (),
                     min_noun_freq: int = 1,
                     rewrite_rules: (Iterable[TreeRewriteRule | str]
                                     | None) = ('determiner', 'auxiliary'),
                     foliated_frame_labels: bool = True) -> Diagram:
        """Return the DisCoCirc diagram for a given text.

        Parameters
        ----------
        text : str
            A single string that contains one or multiple sentences.
        sandwich : bool, default: False
            If False, use frames for higher-order boxes; otherwise
            expand higher-order boxes into 'sandwiches', with one box
            between each subdiagram of the higher-order box.
        break_cycles : bool, default: True
            Whether to break any cycles present in the pregroup tree.
        pruned_nouns : iterable of strings, default: ()
            If any of the nouns in this list are present in the
            diagram, the corresponding state and wire are removed
            from the diagram.
        min_noun_freq : int, default: 1
            Minimum number of times a noun needs to be referenced to
            appear in the circuit.
        rewrite_rules : list of `TreeRewriteRule` or str
            List of rewrite rules to apply to the pregroup tree
            before conversion to a circuit.
        foliated_frame_labels : bool, default: True
            Only relevant when `sandwich` is True. If True, the
            sandwich boxes of a frame are labelled with numbered
            suffixes; if False, all layers of a sandwich carry the
            same label.

        Returns
        -------
        Diagram
            A DisCoCirc diagram for the given text.

        """
        # Tokenise the text and resolve coreferences across sentences
        sentences, corefs = self.coref_resolver.tokenise_and_coref(text)
        corefd = self.coref_resolver.dict_from_corefs(corefs)

        # Prune nouns referenced fewer than `min_noun_freq` times
        noun_counts = Counter(corefd.values())
        freq_pruned_ids = [nid for nid, count in noun_counts.items()
                           if count < min_noun_freq]
        pruned_nouns = set(pruned_nouns).union(
            {sentences[i][j] for (i, j) in freq_pruned_ids})
        pruned_ids = self._prune_indices(sentences, corefs, pruned_nouns)

        rewriter = TreeRewriter(rewrite_rules)

        bigdiag = Id()
        noun2wire = {}
        noun_boxes = []

        for i, sentence in enumerate(sentences):
            tree = self._sentence2tree(sentence, break_cycles)

            # Map the indices of the parsed tokens back onto the
            # original tokenisation
            tree_toks = tree.get_words()
            tree_toks_indxs = tree.get_word_indices()
            reidxr = self._calculate_reindices(sentence,
                                               tree_toks,
                                               tree_toks_indxs)
            reidxr[None] = None

            tree = rewriter(tree)
            tree = self._reindex_nodes(tree, reidxr)

            (sdiag,
             nouns,
             nids) = self._tree2circuital(tree,
                                          sandwich,
                                          pruned_ids[i],
                                          foliated_frame_labels)

            qual_nids = [(i, nid) for nid in nids]
            unique_nids = [corefd.get(nid, nid) for nid in qual_nids]

            # Assign a wire to each noun the first time it is referenced
            wire_ids = []
            ancilla_nouns = set(noun2wire.values())
            pre_ancillae_count = 0
            for j, nid in enumerate(unique_nids):
                if noun2wire.get(nid) is None:
                    noun2wire[nid] = len(noun2wire)
                    pre_ancillae_count += 1
                    noun_boxes.append(nouns[j])
                ancilla_nouns = ancilla_nouns - {noun2wire[nid]}
                wire_ids.append(noun2wire[nid])

            wire_ids = list(sorted(ancilla_nouns)) + wire_ids
            wire_ids_lookup = {wid: len(wire_ids) - i - 1
                               for i, wid in enumerate(reversed(
                                   sorted(wire_ids)))}
            perm_list = []
            for wid in wire_ids:
                perm_list.append(wire_ids_lookup[wid])
                wire_ids_lookup[wid] += 1

            wire_counter = Counter(wire_ids)
            spidering = Id().tensor(
                *[Id(NOUN) if wire_counter[wid] == 1
                  else Spider(NOUN, 1, wire_counter[wid])
                  for wid in range(max(wire_ids) + 1)])
            permute = Diagram.permutation(NOUN ** len(wire_ids), perm_list)
            matcher = spidering >> permute

            bigdiag = (bigdiag @ (NOUN ** pre_ancillae_count)
                       >> matcher
                       >> (NOUN ** len(ancilla_nouns)) @ sdiag
                       >> matcher.dagger())

        return Id().tensor(*noun_boxes) >> bigdiag

    def _calculate_reindices(self,
                             orig_toks,
                             parsed_toks,
                             parsed_toks_indxs):
        reindexer = {}
        j = 0
        for i, otok in zip(parsed_toks_indxs, parsed_toks):
            while j < len(orig_toks) and orig_toks[j] != otok:
                j += 1
            reindexer[i] = j
            j += 1
        return reindexer

    def _reindex_nodes(self, node, reindexer):
        node.ind = reindexer[node.ind]
        node.children = [self._reindex_nodes(child, reindexer)
                         for child in node.children]
        return node
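

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library source): a minimal
# example of driving the reader end-to-end, assuming the default BobcatParser
# and Maverick coreference models are available locally. The sample text and
# the draw() call are assumptions added for demonstration.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    reader = DisCoCircReader()   # BobcatParser + MaverickCoreferenceResolver

    # Convert a short two-sentence text into a single DisCoCirc diagram;
    # coreference resolution should connect 'She'/'him' to the noun wires
    # introduced for 'Alice' and 'Bob'.
    circuit = reader.text2circuit(
        'Alice met Bob. She greeted him.',
        sandwich=False,                 # nested frames for higher-order boxes
        rewrite_rules=('determiner', 'auxiliary'),
    )

    circuit.draw()                      # requires a matplotlib backend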