# Source code for fonduer.features.feature_libs.visual_features

"""Fonduer visual feature extractor."""
from typing import Dict, Iterator, List, Set, Tuple, Union

from fonduer.candidates.models import Candidate
from fonduer.candidates.models.span_mention import SpanMention, TemporarySpanMention
from fonduer.utils.data_model_utils import (
    get_visual_aligned_lemmas,
    is_horz_aligned,
    is_vert_aligned,
    is_vert_aligned_center,
    is_vert_aligned_left,
    is_vert_aligned_right,
    same_page,
)

# Prefix prepended to the name of every feature yielded by
# ``extract_visual_features``.
FEAT_PRE = "VIZ_"
# Value attached to every feature produced by the ``_vizlib_*`` generators.
DEF_VALUE = 1

# Module-level memoization of already computed feature sets.
# ``unary_vizlib_feats`` is keyed by a span's ``stable_id``;
# ``multinary_vizlib_feats`` is keyed by ``candidate.id``.
# NOTE(review): nothing in this module evicts entries, so both caches grow
# for the lifetime of the process -- confirm that this is intended.
# NOTE(review): ``extract_visual_features`` annotates ``candidate.id`` as
# ``int`` while the key type here is ``str`` -- confirm which is correct.
unary_vizlib_feats: Dict[str, Set] = {}
multinary_vizlib_feats: Dict[str, Set] = {}


def _cached_unary_features(span: SpanMention) -> Set[Tuple[str, int]]:
    """Return the memoized unary visual features of a single span.

    The features are computed by ``_vizlib_unary_features`` the first time a
    span with a given ``stable_id`` is seen and are stored in the
    module-level ``unary_vizlib_feats`` cache.

    :param span: The span to compute (or look up) features for.
    :return: The set of ``(feature_name, value)`` pairs for ``span``.
    """
    key = span.stable_id
    if key not in unary_vizlib_feats:
        # Build the complete set before storing it so that an exception
        # raised while generating features cannot leave a partially filled
        # entry behind that later calls would silently reuse.
        unary_vizlib_feats[key] = set(_vizlib_unary_features(span))
    return unary_vizlib_feats[key]


def extract_visual_features(
    candidates: Union[Candidate, List[Candidate]],
) -> Iterator[Tuple[int, str, int]]:
    """Extract visual features.

    Features are only produced for candidates whose spans all belong to
    sentences for which ``is_visual()`` is true; other candidates yield
    nothing. Computed feature sets are memoized in the module-level
    ``unary_vizlib_feats`` and ``multinary_vizlib_feats`` caches.

    :param candidates: A candidate or a list of candidates to extract
        features from.
    :return: A generator of ``(candidate.id, feature_name, value)`` tuples.
        Feature names start with ``FEAT_PRE``; for candidates with more than
        one span, the per-span features additionally carry an ``e{i}_``
        prefix, where ``i`` is the position of the span in the candidate.
    :raises ValueError: If any mention context of a candidate is not a
        ``TemporarySpanMention``.
    """
    candidates = candidates if isinstance(candidates, list) else [candidates]
    for candidate in candidates:
        spans = tuple(m.context for m in candidate.get_mentions())
        if not all(isinstance(span, TemporarySpanMention) for span in spans):
            raise ValueError(
                f"Visual feature only accepts Span-type arguments, "
                f"{type(candidate)}-type found."
            )

        # Add VisualLib entity features (if applicable): every span has to
        # come from a visual sentence, otherwise the candidate is skipped.
        if not all(span.sentence.is_visual() for span in spans):
            continue

        # Unary candidates
        if len(spans) == 1:
            for f, v in _cached_unary_features(spans[0]):
                yield candidate.id, FEAT_PRE + f, v
            continue

        # Multinary candidates: first the per-span features, tagged with the
        # position of the span inside the candidate ...
        for i, span in enumerate(spans):
            prefix = f"e{i}_"
            for f, v in _cached_unary_features(span):
                yield candidate.id, FEAT_PRE + prefix + f, v

        # ... then the features describing the spans jointly.
        if candidate.id not in multinary_vizlib_feats:
            # As for the unary cache, only store a fully built set.
            multinary_vizlib_feats[candidate.id] = set(
                _vizlib_multinary_features(spans)
            )
        for f, v in multinary_vizlib_feats[candidate.id]:
            yield candidate.id, FEAT_PRE + f, v
def _vizlib_unary_features(span: SpanMention) -> Iterator[Tuple[str, int]]:
    """Yield the visual features that describe one span on its own.

    Nothing is yielded when the sentence of ``span`` is not visual.
    Otherwise one ``ALIGNED_<lemma>`` feature is yielded per item returned by
    ``get_visual_aligned_lemmas(span)``, followed by one ``PAGE_[<page>]``
    feature per distinct value of the span's ``page`` attribute tokens.

    :param span: The span to describe.
    :return: A generator of ``(feature_name, DEF_VALUE)`` pairs.
    """
    if span.sentence.is_visual():
        for lemma in get_visual_aligned_lemmas(span):
            yield f"ALIGNED_{lemma}", DEF_VALUE

        # De-duplicate the pages so each one is reported once.
        distinct_pages = set(span.get_attrib_tokens("page"))
        for page_no in distinct_pages:
            yield f"PAGE_[{page_no}]", DEF_VALUE


def _vizlib_multinary_features(
    spans: Tuple[SpanMention, ...]
) -> Iterator[Tuple[str, int]]:
    """Yield the visual features that describe several spans jointly.

    Each predicate below is evaluated on ``spans`` in the listed order, and
    the associated feature name is yielded whenever the predicate holds.

    :param spans: The spans of one candidate.
    :return: A generator of ``(feature_name, DEF_VALUE)`` pairs.
    """
    checks = (
        (same_page, "SAME_PAGE"),
        (is_horz_aligned, "HORZ_ALIGNED"),
        (is_vert_aligned, "VERT_ALIGNED"),
        (is_vert_aligned_left, "VERT_ALIGNED_LEFT"),
        (is_vert_aligned_right, "VERT_ALIGNED_RIGHT"),
        (is_vert_aligned_center, "VERT_ALIGNED_CENTER"),
    )
    for holds_for, label in checks:
        if holds_for(spans):
            yield label, DEF_VALUE