Source code for fonduer.candidates.candidates

import logging
from builtins import range
from itertools import product
from typing import (
    Any,
    Callable,
    Collection,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
    Union,
)

from sqlalchemy.orm import Session

from fonduer.candidates.models import Candidate, Mention
from fonduer.parser.models.document import Document
from fonduer.utils.udf import UDF, UDFRunner
from fonduer.utils.utils import get_set_of_stable_ids

logger = logging.getLogger(__name__)


class CandidateExtractor(UDFRunner):
    """An operator to extract Candidate objects from a Context.

    :Example:

        Assuming we have already defined a Part and Temp ``Mention`` subclass,
        and a throttler called temp_throttler, we can create a candidate
        extractor as follows::

            PartTemp = candidate_subclass("PartTemp", [Part, Temp])
            candidate_extractor = CandidateExtractor(
                session, [PartTemp], throttlers=[temp_throttler]
            )

    :param session: An initialized database session.
    :param candidate_classes: The types of relation to extract, defined using
        :func:`fonduer.candidates.candidate_subclass`.
    :param throttlers: optional functions for filtering out candidates
        which returns a Boolean expressing whether or not the candidate should
        be instantiated.
    :type throttlers: list of throttlers.
    :param self_relations: Boolean indicating whether to extract Candidates
        that relate the same context. Only applies to binary relations.
    :param nested_relations: Boolean indicating whether to extract Candidates
        that relate one Context with another that contains it. Only applies to
        binary relations.
    :param symmetric_relations: Boolean indicating whether to extract symmetric
        Candidates, i.e., rel(A,B) and rel(B,A), where A and B are Contexts.
        Only applies to binary relations.
    :param parallelism: The number of processes to use in parallel for calls
        to apply().
    :raises ValueError: If throttlers are provided, but the number of
        throttlers differs from the number of candidate classes.
    """

    def __init__(
        self,
        session: Session,
        candidate_classes: List[Type[Candidate]],
        throttlers: Optional[List[Callable[[Tuple[Mention, ...]], bool]]] = None,
        self_relations: bool = False,
        nested_relations: bool = False,
        symmetric_relations: bool = True,
        parallelism: int = 1,
    ) -> None:
        """Initialize the CandidateExtractor.

        When no throttlers are given, one ``None`` placeholder is used per
        candidate class so that throttlers and classes stay index-aligned.
        """
        if throttlers is None:
            throttlers = [None] * len(candidate_classes)

        # Check that arity is sensible. This is validated before the base
        # class is initialized so that a mismatch fails fast.
        if len(candidate_classes) != len(throttlers):
            raise ValueError(
                "Provided different number of throttlers and candidate classes."
            )

        super().__init__(
            session,
            CandidateExtractorUDF,
            parallelism=parallelism,
            candidate_classes=candidate_classes,
            throttlers=throttlers,
            self_relations=self_relations,
            nested_relations=nested_relations,
            symmetric_relations=symmetric_relations,
        )

        self.candidate_classes = candidate_classes

    def apply(  # type: ignore
        self,
        docs: Collection[Document],
        split: int = 0,
        clear: bool = True,
        parallelism: Optional[int] = None,
        progress_bar: bool = True,
    ) -> None:
        """Run the CandidateExtractor.

        :Example: To extract candidates from a set of training documents using
            4 cores::

                candidate_extractor.apply(train_docs, split=0, parallelism=4)

        :param docs: Set of documents to extract from.
        :param split: Which split to assign the extracted Candidates to.
        :param clear: Whether or not to clear the existing Candidates
            beforehand.
        :param parallelism: Degree of parallelism to use for extraction. This
            will override the parallelism value used to initialize the
            CandidateExtractor if it is provided.
        :param progress_bar: Whether or not to display a progress bar. The
            progress bar is measured per document.
        """
        super().apply(
            docs,
            split=split,
            clear=clear,
            parallelism=parallelism,
            progress_bar=progress_bar,
        )

    def clear(self, split: int) -> None:  # type: ignore
        """Delete Candidates of this extractor's classes from the given split.

        Only the candidate classes this CandidateExtractor was initialized
        with are deleted.

        :param split: Which split to clear.
        """
        for candidate_class in self.candidate_classes:
            logger.info(
                f"Clearing table {candidate_class.__tablename__} (split {split})"
            )
            self.session.query(Candidate).filter(
                Candidate.type == candidate_class.__tablename__
            ).filter(Candidate.split == split).delete(synchronize_session="fetch")

    def clear_all(self, split: int) -> None:
        """Delete ALL Candidates from the given split in the database.

        :param split: Which split to clear.
        """
        logger.info("Clearing ALL Candidates.")
        self.session.query(Candidate).filter(Candidate.split == split).delete(
            synchronize_session="fetch"
        )

    @staticmethod
    def _sort_by_stable_id(cands: List[Candidate]) -> List[Candidate]:
        """Return ``cands`` sorted by their underscore-joined stable ids.

        The key of a candidate is the stable id of the context of each of its
        mentions, joined with ``"_"``.
        """
        return sorted(
            cands,
            key=lambda x: "_".join(
                [x[i].context.get_stable_id() for i in range(len(x))]
            ),
        )

    def get_candidates(
        self,
        docs: Union[Document, Iterable[Document], None] = None,
        split: int = 0,
        sort: bool = False,
    ) -> List[List[Candidate]]:
        """Return a list of lists of the candidates associated with this extractor.

        Each list of the return will contain the candidates for one of the
        candidate classes associated with the CandidateExtractor.

        :param docs: If provided, return candidates from these documents from
            all splits.
        :param split: If docs is None, then return all the candidates from this
            split.
        :param sort: If sort is True, then return all candidates sorted by
            stable_id.
        :return: Candidates for each candidate_class.
        """
        doc_ids = None
        sub_query = None
        if docs:
            doc_iter = docs if isinstance(docs, Iterable) else [docs]
            # Collect the ids once, up front: ``docs`` may be a one-shot
            # iterable (e.g. a generator), which would otherwise be exhausted
            # after the first candidate class.
            doc_ids = [doc.id for doc in doc_iter]
        else:
            # Candidate ids in the requested split; identical for every
            # candidate class, so it is built a single time.
            sub_query = (
                self.session.query(Candidate.id)
                .filter(Candidate.split == split)
                .subquery()
            )

        result = []
        for candidate_class in self.candidate_classes:
            if doc_ids is not None:
                # Get cands from all splits
                criterion = candidate_class.document_id.in_(doc_ids)
            else:
                # Filter by candidate_ids in a particular split
                criterion = candidate_class.id.in_(sub_query)

            cands = (
                self.session.query(candidate_class)
                .filter(criterion)
                .order_by(candidate_class.id)
                .all()
            )
            if sort:
                cands = self._sort_by_stable_id(cands)
            result.append(cands)
        return result
# Type alias for throttler
Throttler = Callable[[Tuple[Mention, ...]], bool]


class CandidateExtractorUDF(UDF):
    """UDF for performing candidate extraction."""

    def __init__(
        self,
        candidate_classes: Union[Type[Candidate], List[Type[Candidate]]],
        throttlers: Union[Throttler, List[Throttler]],
        self_relations: bool,
        nested_relations: bool,
        symmetric_relations: bool,
        **kwargs: Any,
    ) -> None:
        """Initialize the CandidateExtractorUDF.

        A single candidate class or a single throttler is accepted as well as
        a list/tuple of them; a lone value is wrapped into a one-item list.
        """
        # Normalize both arguments to sequences.
        if not isinstance(candidate_classes, (list, tuple)):
            candidate_classes = [candidate_classes]
        if not isinstance(throttlers, (list, tuple)):
            throttlers = [throttlers]

        self.candidate_classes = candidate_classes
        self.throttlers = throttlers

        self.nested_relations = nested_relations
        self.self_relations = self_relations
        self.symmetric_relations = symmetric_relations

        # Number of arguments (mentions) each candidate class relates.
        self.arities = [len(cclass.__argnames__) for cclass in self.candidate_classes]

        super().__init__(**kwargs)

    def apply(  # type: ignore
        self, doc: Document, split: int, **kwargs: Any
    ) -> Document:
        """Extract candidates from the given Context.

        :param doc: A document to process.
        :param split: Which split to use.
        :return: The same document that was passed in.
        """
        logger.debug(f"Document: {doc}")

        # Handle the candidate classes one after another.
        for class_idx, candidate_class in enumerate(self.candidate_classes):
            logger.debug(f" Relation: {candidate_class.__name__}")
            arity = self.arities[class_idx]

            # Keyword arguments shared by every candidate of this class; the
            # per-mention arguments are overwritten for each combination.
            candidate_args = {"split": split, "document": doc}

            # One enumerated list of mentions per mention subclass of the
            # document; their cross product yields the proposed candidates.
            enumerated_mentions = [
                enumerate(getattr(doc, mention.__tablename__ + "s"))
                for mention in candidate_class.mentions
            ]
            cands = product(*enumerated_mentions)

            # Get a set of stable_ids of candidates.
            set_of_stable_ids = get_set_of_stable_ids(doc, candidate_class)

            for cand in cands:
                # Apply throttler if one was given. It receives a tuple of
                # Mention objects and returns whether the proposed candidate
                # passes the throttling condition.
                throttler = self.throttlers[class_idx]
                if throttler and not throttler(
                    tuple(cand[j][1] for j in range(arity))
                ):
                    continue

                # TODO: Make this work for higher-order relations
                if arity == 2:
                    (ai, first_mention), (bi, second_mention) = cand[0], cand[1]
                    a = first_mention.context
                    b = second_mention.context

                    # Check for self-joins, "nested" joins (joins from context
                    # to its subcontext), and flipped duplicate "symmetric"
                    # relations
                    if not self.self_relations and a == b:
                        logger.debug(f"Skipping self-joined candidate {cand}")
                        continue
                    if not self.nested_relations and (a in b or b in a):
                        logger.debug(f"Skipping nested candidate {cand}")
                        continue
                    if not self.symmetric_relations and ai > bi:
                        logger.debug(f"Skipping symmetric candidate {cand}")
                        continue

                # Assemble candidate arguments
                candidate_args.update(
                    {
                        arg_name: cand[j][1]
                        for j, arg_name in enumerate(candidate_class.__argnames__)
                    }
                )

                stable_ids = tuple(
                    cand[j][1].context.get_stable_id() for j in range(arity)
                )
                # Skip if this (temporary) candidate is used by this candidate
                # class.
                if hasattr(doc, candidate_class.__tablename__ + "s"):
                    if stable_ids in set_of_stable_ids:
                        continue

                # Add Candidate to session
                candidate_class(**candidate_args)

        return doc