Source code for segram.nlp.pipeline.coref

"""Segram coreference pipeline component."""
from typing import Any, Sequence, Self
import os
import spacy
from spacy.language import Language
from spacy.tokens import Doc
from spacy.training import Alignment
from .base import Segram
from ... import __title__
from ...utils.meta import get_cname


[docs] class Coref: """Coreference resolution pipeline component based on :mod:`spacy` coref component. Attributes ---------- name Pipe name. model Language model for coreference resolution. """
[docs] def __init__( self, nlp: Language, name: str, model: Language, components: Sequence[str] | None = None ) -> None: """Initilization method. Parameters ---------- nlp Main language model. model Name of a coreference language model. components Names of pipeline component names to include. Use all if ``None``. Raises ------ ValueError If ``components`` are empty but not ``None``. """ try: segram = dict(nlp.pipeline)[__title__] self.alias = segram.alias except KeyError as exc: raise RuntimeError( f"'{name}' component can be " f"initialized only after '{__title__}'" ) from exc self.nlp = nlp self.name = name self.model = model if components is not None and not components: raise ValueError( f"'{get_cname(self)}' pipe initizialized " "with empty 'components' list" ) if components: self.model.select_pipes(enable=components)
def __call__(self, doc: Doc) -> Doc: cdoc = self.model(doc.text) s1 = list(map(str, doc)) s2 = list(map(str, cdoc)) align = Alignment.from_strings(s1, s2) for spans in cdoc.spans.values(): cluster = [ int(align.y2x.data[t.i]) for s in spans for t in s ] self.set_corefs(doc, cluster) getattr(doc._, f"{self.alias}_meta")["coref"] = \ Segram.get_model_info(self.model) return doc
[docs] def set_corefs(self, doc: Doc, cluster: Sequence[int]) -> None: """Set proper coreferences from pronoun tokens to closest non-pronoun neighbors within the ``cluster``. Notes ----- Coreferences are stored as token indexes (integers) in ``_ref`` custom attribute on tokens. """ # pylint: disable=protected-access alias = self.alias proper = [] pronouns = [] for i in cluster: if getattr(doc[i]._, alias+"_sns").is_pron: pronouns.append(i) else: proper.append(i) for i in pronouns: closest = None for j in proper: if closest is None or abs(i-j) < closest: closest = j if closest is not None: corefs = set() corefs.add(closest) for conj in doc[closest].conjuncts: corefs.add(conj.i) setattr(doc[i]._, f"{alias}_corefs", tuple(sorted(corefs)))
[docs] @classmethod def from_model( cls, nlp: Language, name: str, model: str, components: Sequence[str] | None = None, **kwds: Any ) -> Self: """Initialize from model name. ``**kwds`` are passed to :func:`spacy.load`. """ model = spacy.load(model, **kwds) return cls(nlp, name, model, components)
[docs] def to_disk(self, path: str | bytes | os.PathLike, **kwds: Any) -> None: """Serialize the coreference model to disk.""" self.model.to_disk(path, **kwds)
[docs] def from_disk(self, path: str | bytes | os.PathLike, **kwds: Any) -> Self: """Load from disk.""" self.model = spacy.load(path, **kwds) return self