"""Fonduer data model utils."""
import logging
from functools import lru_cache
from typing import Callable, Iterable, List, Set, Union
from fonduer.candidates.models import Candidate, Mention
from fonduer.candidates.models.span_mention import TemporarySpanMention
@lru_cache(maxsize=1024)
def _to_span(
    x: Union[Candidate, Mention, TemporarySpanMention], idx: int = 0
) -> TemporarySpanMention:
    """Resolve a Candidate, Mention, or span to a single span.

    Results are memoized by ``lru_cache``, so the arguments must be hashable.

    :param x: The object to resolve.
    :param idx: Index into ``x`` used only when ``x`` is a Candidate.
    :return: ``x[idx].context`` for a Candidate, ``x.context`` for a Mention,
        or ``x`` itself when it is already a TemporarySpanMention.
    :raises ValueError: If ``x`` is none of the accepted types.
    """
    # The checks run in the same order as the accepted types are listed.
    if isinstance(x, Candidate):
        return x[idx].context
    if isinstance(x, Mention):
        return x.context
    if isinstance(x, TemporarySpanMention):
        return x
    raise ValueError(f"{type(x)} is an invalid argument type")
@lru_cache(maxsize=1024)
def _to_spans(
    x: Union[Candidate, Mention, TemporarySpanMention]
) -> List[TemporarySpanMention]:
    """Resolve a Candidate, Mention, or span to a list of spans.

    Results are memoized by ``lru_cache``, so ``x`` must be hashable.

    :param x: The object to resolve.
    :return: One span per item of ``x`` for a Candidate (each item resolved
        through ``_to_span``), ``[x.context]`` for a Mention, or ``[x]`` when
        ``x`` is already a TemporarySpanMention.
    :raises ValueError: If ``x`` is none of the accepted types.
    """
    if isinstance(x, Candidate):
        spans = []
        for member in x:
            spans.append(_to_span(member))
        return spans
    if isinstance(x, Mention):
        return [x.context]
    if isinstance(x, TemporarySpanMention):
        return [x]
    raise ValueError(f"{type(x)} is an invalid argument type")
def is_superset(a: Iterable, b: Iterable) -> bool:
    """Check if a is a superset of b.

    This is typically used to check if ALL of a list of sentences is in the
    ngrams returned by an lf_helper.

    :param a: A collection of items
    :param b: A collection of items
    :return: True if every item of ``b`` is also in ``a``, otherwise False.
        An empty ``b`` always yields True.
    """
    # Only ``a`` needs to be materialized as a set; ``issuperset`` accepts
    # any iterable as its argument.
    return set(a).issuperset(b)
def overlap(a: Iterable, b: Iterable) -> bool:
    """Check if a overlaps b.

    This is typically used to check if ANY of a list of sentences is in the
    ngrams returned by an lf_helper.

    :param a: A collection of items
    :param b: A collection of items
    :return: True if ``a`` and ``b`` share at least one item, otherwise
        False. If either collection is empty the result is False.
    """
    # Two collections overlap exactly when they are not disjoint.
    return not set(a).isdisjoint(b)
def get_matches(
    lf: Callable,
    candidate_set: Set[Candidate],
    match_values: Union[List[int], None] = None,
) -> List[Candidate]:
    """Return a list of candidates that are matched by a particular LF.

    A simple helper function to see how many matches (non-zero by default) an
    LF gets. The number of matches is also logged at INFO level.

    :param lf: The labeling function to apply to the candidate_set
    :param candidate_set: The set of candidates to evaluate
    :param match_values: An optional list of the values to consider as
        matched. [1, -1] by default.
    :return: The candidates for which ``lf`` returned one of
        ``match_values``.
    """
    # A None sentinel replaces the former ``[1, -1]`` literal default so that
    # no mutable list object is shared between calls.
    if match_values is None:
        match_values = [1, -1]

    logger = logging.getLogger(__name__)
    matches: List[Candidate] = [
        c for c in candidate_set if lf(c) in match_values
    ]
    logger.info(f"{len(matches)} matches")
    return matches