# Source code for fonduer.candidates.models.implicit_span_mention

"""Fonduer implicit span mention model."""
from typing import Any, Dict, List, Optional, Type

from sqlalchemy import Column, ForeignKey, Integer, String, UniqueConstraint
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import backref, relationship
from sqlalchemy.sql import text
from sqlalchemy.types import PickleType

from fonduer.candidates.models.span_mention import TemporarySpanMention
from fonduer.parser.models.context import Context
from fonduer.parser.models.sentence import Sentence
from fonduer.parser.models.utils import split_stable_id


class TemporaryImplicitSpanMention(TemporarySpanMention):
    """The TemporaryContext version of ImplicitSpanMention.

    Holds an implicit span (text that may not appear verbatim in the sentence)
    together with its per-word annotations, before it is persisted as an
    :class:`ImplicitSpanMention`.
    """

    def __init__(
        self,
        sentence: Sentence,
        char_start: int,
        char_end: int,
        expander_key: str,
        position: int,
        text: str,
        words: List[str],
        lemmas: List[str],
        pos_tags: List[str],
        ner_tags: List[str],
        dep_parents: List[int],
        dep_labels: List[str],
        page: List[Optional[int]],
        top: List[Optional[int]],
        left: List[Optional[int]],
        bottom: List[Optional[int]],
        right: List[Optional[int]],
        meta: Any = None,
    ) -> None:
        """Initialize TemporaryImplicitSpanMention.

        :param sentence: The parent sentence.
        :param char_start: Starting character index (within the sentence) of the
            text segment that was expanded to produce this span.
        :param char_end: Ending character index (inclusive) of that segment.
        :param expander_key: Key identifying the expander function that
            produced this span.
        :param position: Index of this span among those produced by the
            expander (0 for the first one).
        :param text: The raw text of the implicit span.
        :param words: The words of the implicit span.
        :param lemmas: Lemma of each word.
        :param pos_tags: POS tag of each word.
        :param ner_tags: NER tag of each word.
        :param dep_parents: Dependency parent of each word.
        :param dep_labels: Dependency label of each word.
        :param page: Page number of each word.
        :param top: TOP bounding-box coordinate of each word.
        :param left: LEFT bounding-box coordinate of each word.
        :param bottom: BOTTOM bounding-box coordinate of each word.
        :param right: RIGHT bounding-box coordinate of each word.
        :param meta: Arbitrary metadata, forwarded to the parent constructor.
        """
        super().__init__(sentence, char_start, char_end, meta)
        self.expander_key = expander_key
        self.position = position
        self.text = text
        self.words = words
        self.lemmas = lemmas
        self.pos_tags = pos_tags
        self.ner_tags = ner_tags
        self.dep_parents = dep_parents
        self.dep_labels = dep_labels
        self.page = page
        self.top = top
        self.left = left
        self.bottom = bottom
        self.right = right

    def __len__(self) -> int:
        """Get the length of the mention (total characters across ``words``)."""
        return sum(map(len, self.words))

    def __eq__(self, other: object) -> bool:
        """Check if the mention is equal to another mention.

        Two temporary implicit mentions are equal when they share the sentence,
        character range, expander key and position.
        """
        if not isinstance(other, TemporaryImplicitSpanMention):
            return NotImplemented
        return (
            self.sentence == other.sentence
            and self.char_start == other.char_start
            and self.char_end == other.char_end
            and self.expander_key == other.expander_key
            and self.position == other.position
        )

    def __ne__(self, other: object) -> bool:
        """Check if the mention is not equal to another mention."""
        if not isinstance(other, TemporaryImplicitSpanMention):
            return NotImplemented
        return (
            self.sentence != other.sentence
            or self.char_start != other.char_start
            or self.char_end != other.char_end
            or self.expander_key != other.expander_key
            or self.position != other.position
        )

    def __hash__(self) -> int:
        """Get the hash value of mention.

        Hashes exactly the fields compared by ``__eq__`` so that equal mentions
        hash equally; a tuple hash mixes the fields instead of summing them.
        """
        return hash(
            (
                self.sentence,
                self.char_start,
                self.char_end,
                self.expander_key,
                self.position,
            )
        )

    def get_stable_id(self) -> str:
        """Return a stable id.

        The format is
        ``<document name>::implicit_span_mention:<start>:<end>:<expander_key>:<position>``,
        where start/end are the span offsets shifted by the sentence's offset.
        """
        # idx[0] is used as the sentence's starting char offset in the
        # document, converting sentence-relative offsets to document-level ones.
        _, _, idx = split_stable_id(self.sentence.stable_id)
        parent_doc_char_start = idx[0]
        return (
            f"{self.sentence.document.name}"
            f"::"
            f"{self._get_polymorphic_identity()}"
            f":"
            f"{parent_doc_char_start + self.char_start}"
            f":"
            f"{parent_doc_char_start + self.char_end}"
            f":"
            f"{self.expander_key}"
            f":"
            f"{self.position}"
        )

    def _get_table(self) -> Type["ImplicitSpanMention"]:
        # The ORM class this temporary mention is persisted as.
        return ImplicitSpanMention

    def _get_polymorphic_identity(self) -> str:
        # Must match ImplicitSpanMention's ``polymorphic_identity``.
        return "implicit_span_mention"

    def _get_insert_args(self) -> Dict[str, Any]:
        # Column values used when inserting this mention into its table.
        return {
            "sentence_id": self.sentence.id,
            "char_start": self.char_start,
            "char_end": self.char_end,
            "expander_key": self.expander_key,
            "position": self.position,
            "text": self.text,
            "words": self.words,
            "lemmas": self.lemmas,
            "pos_tags": self.pos_tags,
            "ner_tags": self.ner_tags,
            "dep_parents": self.dep_parents,
            "dep_labels": self.dep_labels,
            "page": self.page,
            "top": self.top,
            "left": self.left,
            "bottom": self.bottom,
            "right": self.right,
            "meta": self.meta,
        }

    def get_attrib_tokens(self, a: str = "words") -> List:
        """Get the tokens of sentence attribute *a*.

        Intuitively, like calling::

            implicit_span.a


        :param a: The attribute to get tokens for.
        :return: The tokens of sentence attribute defined by *a* for the span.
        """
        return getattr(self, a)

    def get_attrib_span(self, a: str, sep: str = "") -> str:
        """Get the span of sentence attribute *a*.

        Intuitively, like calling::

            sep.join(implicit_span.a)

        :param a: The attribute to get a span for.
        :param sep: The separator to use for the join,
                    or to be removed from text if a="words".
        :return: The joined tokens, or text if a="words".
        """
        if a == "words":
            # For words, the stored raw text is used rather than re-joining.
            return self.text.replace(sep, "")
        return sep.join([str(n) for n in self.get_attrib_tokens(a)])

    def __getitem__(self, key: slice) -> "TemporaryImplicitSpanMention":
        """Slice operation returns a new candidate sliced according to **char index**.

        Note that the slicing is w.r.t. the candidate range (not the abs.
        sentence char indexing). Negative bounds count back from the end of
        the span, as in Python slicing. The step is ignored. The text and
        per-word annotations are carried over unchanged.

        :raises NotImplementedError: If *key* is not a slice.
        """
        if not isinstance(key, slice):
            raise NotImplementedError()

        if key.start is None:
            char_start = self.char_start
        elif key.start >= 0:
            char_start = self.char_start + key.start
        else:
            # [-k:] keeps the last k characters; char_end is inclusive.
            char_start = self.char_end + 1 + key.start

        if key.stop is None:
            char_end = self.char_end
        elif key.stop >= 0:
            # Slice stop is exclusive while char_end is inclusive.
            char_end = self.char_start + key.stop - 1
        else:
            char_end = self.char_end + key.stop

        return self._get_instance(
            sentence=self.sentence,
            char_start=char_start,
            char_end=char_end,
            expander_key=self.expander_key,
            position=self.position,
            # Previously ``text=text``, which picked up the module-level
            # ``sqlalchemy.sql.text`` function instead of this mention's text.
            text=self.text,
            words=self.words,
            lemmas=self.lemmas,
            pos_tags=self.pos_tags,
            ner_tags=self.ner_tags,
            dep_parents=self.dep_parents,
            dep_labels=self.dep_labels,
            page=self.page,
            top=self.top,
            left=self.left,
            bottom=self.bottom,
            right=self.right,
            meta=self.meta,
        )

    def __repr__(self) -> str:
        """Represent the mention as a string."""
        return (
            f"{self.__class__.__name__}"
            f"("
            f'"{self.get_span()}", '
            f"sentence={self.sentence.id}, "
            f"words=[{self.get_word_start_index()},{self.get_word_end_index()}], "
            f"position=[{self.position}]"
            f")"
        )

    def _get_instance(self, **kwargs: Any) -> "TemporaryImplicitSpanMention":
        # Factory used by ``__getitem__``; subclasses override it.
        return TemporaryImplicitSpanMention(**kwargs)


class ImplicitSpanMention(Context, TemporaryImplicitSpanMention):
    """A span of characters that may not appear verbatim in the source text.

    It is identified by Context id, character-index start and end (inclusive),
    as well as a key representing what 'expander' function drew the
    ImplicitSpanMention from an existing SpanMention, and a position (where
    position=0 corresponds to the first ImplicitSpanMention produced from the
    expander function).

    The character-index start and end point to the segment of text that was
    expanded to produce the ImplicitSpanMention.
    """

    __tablename__ = "implicit_span_mention"

    #: The unique id of the ``ImplicitSpanMention``.
    id = Column(Integer, ForeignKey("context.id", ondelete="CASCADE"), primary_key=True)

    #: The id of the parent ``Sentence``.
    # NOTE(review): sentence_id is also flagged primary_key, so the table's
    # primary key is (id, sentence_id); left unchanged to avoid a schema change.
    sentence_id = Column(
        Integer, ForeignKey("context.id", ondelete="CASCADE"), primary_key=True
    )

    #: The parent ``Sentence``.
    sentence = relationship(
        "Context",
        backref=backref("implicit_spans", cascade="all, delete-orphan"),
        foreign_keys=sentence_id,
    )

    #: The starting character-index of the ``ImplicitSpanMention``.
    char_start = Column(Integer, nullable=False)

    #: The ending character-index of the ``ImplicitSpanMention`` (inclusive).
    char_end = Column(Integer, nullable=False)

    #: The key representing the expander function which produced this
    #: ``ImplicitSpanMention``.
    expander_key = Column(String, nullable=False)

    #: The position of the ``ImplicitSpanMention`` where position=0 is the first
    #: ``ImplicitSpanMention`` produced by the expander.
    position = Column(Integer, nullable=False)

    #: The raw text of the ``ImplicitSpanMention``.
    text = Column(String)

    #: A list of the words in the ``ImplicitSpanMention``.
    words = Column(postgresql.ARRAY(String), nullable=False)

    #: A list of the lemmas for each word in the ``ImplicitSpanMention``.
    lemmas = Column(postgresql.ARRAY(String))

    #: A list of the POS tags for each word in the ``ImplicitSpanMention``.
    pos_tags = Column(postgresql.ARRAY(String))

    #: A list of the NER tags for each word in the ``ImplicitSpanMention``.
    ner_tags = Column(postgresql.ARRAY(String))

    #: A list of the dependency parents for each word in the ``ImplicitSpanMention``.
    dep_parents = Column(postgresql.ARRAY(Integer))

    #: A list of the dependency labels for each word in the ``ImplicitSpanMention``.
    dep_labels = Column(postgresql.ARRAY(String))

    #: A list of the page number each word in the ``ImplicitSpanMention``.
    page = Column(postgresql.ARRAY(Integer))

    #: A list of each word's TOP bounding box coordinate in the
    #: ``ImplicitSpanMention``.
    top = Column(postgresql.ARRAY(Integer))

    #: A list of each word's LEFT bounding box coordinate in the
    #: ``ImplicitSpanMention``.
    left = Column(postgresql.ARRAY(Integer))

    #: A list of each word's BOTTOM bounding box coordinate in the
    #: ``ImplicitSpanMention``.
    bottom = Column(postgresql.ARRAY(Integer))

    #: A list of each word's RIGHT bounding box coordinate in the
    #: ``ImplicitSpanMention``.
    right = Column(postgresql.ARRAY(Integer))

    #: Pickled metadata about the ``ImplicitSpanMention``.
    meta = Column(PickleType)

    __table_args__ = (
        UniqueConstraint(sentence_id, char_start, char_end, expander_key, position),
    )

    __mapper_args__ = {
        "polymorphic_identity": "implicit_span_mention",
        "inherit_condition": (id == Context.id),
    }

    def __init__(self, tc: TemporaryImplicitSpanMention) -> None:
        """Initialize ImplicitSpanMention from its temporary counterpart.

        :param tc: The temporary mention whose fields are copied; its stable id
            becomes this row's ``stable_id``.
        """
        self.stable_id = tc.get_stable_id()
        # Assigning ``sentence`` also populates the parent's ``implicit_spans``
        # backref declared on the relationship above.
        self.sentence = tc.sentence
        self.char_start = tc.char_start
        self.char_end = tc.char_end
        self.expander_key = tc.expander_key
        self.position = tc.position
        self.text = tc.text
        self.words = tc.words
        self.lemmas = tc.lemmas
        self.pos_tags = tc.pos_tags
        self.ner_tags = tc.ner_tags
        self.dep_parents = tc.dep_parents
        self.dep_labels = tc.dep_labels
        self.page = tc.page
        self.top = tc.top
        self.left = tc.left
        self.bottom = tc.bottom
        self.right = tc.right
        self.meta = tc.meta

    def _get_instance(self, **kwargs: Any) -> "ImplicitSpanMention":
        # ``__init__`` only accepts a temporary mention, so build one from the
        # keyword arguments first; ``ImplicitSpanMention(**kwargs)`` raised
        # TypeError whenever a persisted mention was sliced.
        # NOTE(review): the new object is attached to the same sentence via the
        # backref, so it may be picked up by that sentence's session through
        # cascade — confirm this is desired for slices.
        return ImplicitSpanMention(TemporaryImplicitSpanMention(**kwargs))

    # We redefine these to use default (identity) semantics, overriding the
    # value-based operators inherited from TemporaryImplicitSpanMention.
    def __eq__(self, other: object) -> bool:
        """Check if the mention is equal to another mention."""
        if not isinstance(other, ImplicitSpanMention):
            return NotImplemented
        return self is other

    def __ne__(self, other: object) -> bool:
        """Check if the mention is not equal to another mention."""
        if not isinstance(other, ImplicitSpanMention):
            return NotImplemented
        return self is not other

    def __hash__(self) -> int:
        """Get the hash value of mention."""
        return id(self)