"""fonduer.supervision.labeler: apply labeling functions (LFs) to candidates and store their labels."""

import logging
from collections import defaultdict
from typing import (
    Any,
    Callable,
    Collection,
    DefaultDict,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    Tuple,
    Type,
    Union,
)

import numpy as np
from sqlalchemy import Table
from sqlalchemy.orm import Session

from fonduer.candidates.models import Candidate
from fonduer.parser.models import Document
from fonduer.supervision.models import GoldLabelKey, Label, LabelKey
from fonduer.utils.udf import UDF, UDFRunner
from fonduer.utils.utils_udf import (
    ALL_SPLITS,
    batch_upsert_records,
    drop_all_keys,
    drop_keys,
    get_docs_from_split,
    get_mapping,
    get_sparse_matrix,
    get_sparse_matrix_keys,
    unshift_label_matrix,
    upsert_keys,
)

# Logger named after this module.
logger = logging.getLogger(__name__)

# Label convention (changed upstream in Snorkel, see
# https://github.com/snorkel-team/snorkel/pull/1309):
#   * ABSTAIN is encoded as -1 (it used to be 0);
#   * consequently, user-defined labels are 0-indexed (they used to be 1-indexed).
ABSTAIN: int = -1


class Labeler(UDFRunner):
    """An operator to add Label Annotations to Candidates.

    :param session: The database session to use.
    :param candidate_classes: A list of candidate_subclasses to label.
    :param parallelism: The number of processes to use in parallel. Default 1.
    """

    def __init__(
        self,
        session: Session,
        candidate_classes: List[Type[Candidate]],
        parallelism: int = 1,
    ):
        """Initialize the Labeler."""
        super().__init__(
            session,
            LabelerUDF,
            parallelism=parallelism,
            candidate_classes=candidate_classes,
        )
        self.candidate_classes = candidate_classes
        self.lfs: List[List[Callable]] = []

    def update(
        self,
        docs: Optional[Collection[Document]] = None,
        split: int = 0,
        lfs: Optional[List[List[Callable]]] = None,
        parallelism: Optional[int] = None,
        progress_bar: bool = True,
        table: Table = Label,
    ) -> None:
        """Update the labels of the specified candidates based on the provided LFs.

        This is ``apply`` with ``train=True`` and ``clear=False``: existing labels
        are not cleared first, and the label keys are rebuilt afterwards.

        :param docs: If provided, apply the updated LFs to all the candidates
            in these documents.
        :param split: If docs is None, apply the updated LFs to the candidates
            in this particular split.
        :param lfs: A list of lists of labeling functions to update. Each list
            should correspond with the candidate_classes used to initialize the
            Labeler.
        :param parallelism: The number of processes to use. This will override
            the parallelism value used to initialize the Labeler if it is
            provided.
        :param progress_bar: Whether or not to display a progress bar. The
            progress bar is measured per document.
        :param table: A (database) table labels are written to.
            Takes `Label` (by default) or `GoldLabel`.
        :raises ValueError: If lfs is None or does not provide one list of LFs
            per candidate class.
        """
        if lfs is None:
            raise ValueError("Please provide a list of lists of labeling functions.")

        if len(lfs) != len(self.candidate_classes):
            raise ValueError("Please provide LFs for each candidate class.")

        self.table = table
        self.apply(
            docs=docs,
            split=split,
            lfs=lfs,
            train=True,
            clear=False,
            parallelism=parallelism,
            progress_bar=progress_bar,
            table=table,
        )

    def apply(  # type: ignore
        self,
        docs: Optional[Collection[Document]] = None,
        split: int = 0,
        train: bool = False,
        lfs: Optional[List[List[Callable]]] = None,
        clear: bool = True,
        parallelism: Optional[int] = None,
        progress_bar: bool = True,
        table: Table = Label,
    ) -> None:
        """Apply the labels of the specified candidates based on the provided LFs.

        :param docs: If provided (and non-empty), apply the LFs to all the
            candidates in these documents, regardless of their split. An empty
            collection is treated like None.
        :param split: If docs is None, apply the LFs to the candidates in this
            particular split.
        :param train: Whether or not to update the global key set of labels and
            the labels of candidates.
        :param lfs: A list of lists of labeling functions to apply. Each list
            should correspond with the candidate_classes used to initialize the
            Labeler.
        :param clear: Whether or not to clear the labels table before applying
            these LFs.
        :param parallelism: The number of processes to use. This will override
            the parallelism value used to initialize the Labeler if it is
            provided.
        :param progress_bar: Whether or not to display a progress bar. The
            progress bar is measured per document.
        :param table: A (database) table labels are written to.
            Takes `Label` (by default) or `GoldLabel`.

        :raises ValueError: If labeling functions are not provided for each
            candidate class.
        """
        if lfs is None:
            raise ValueError("Please provide a list of labeling functions.")

        if len(lfs) != len(self.candidate_classes):
            raise ValueError("Please provide LFs for each candidate class.")

        self.lfs = lfs
        self.table = table
        if docs:
            # Explicit docs: process them for all splits.
            # TODO: split is int
            split = ALL_SPLITS  # type: ignore
            target_docs = docs
        else:
            # Only grab the docs containing candidates from the given split.
            target_docs = get_docs_from_split(
                self.session, self.candidate_classes, split
            )

        super().apply(
            target_docs,
            split=split,
            train=train,
            lfs=self.lfs,
            clear=clear,
            parallelism=parallelism,
            progress_bar=progress_bar,
            table=table,
        )
        # Needed to sync the bulk operations
        self.session.commit()

    def get_keys(self) -> List[LabelKey]:
        """Return a list of keys for the Labels.

        :return: List of LabelKeys.
        """
        return list(get_sparse_matrix_keys(self.session, LabelKey))

    def _build_key_map(
        self,
        keys: Union[Iterable[Union[str, Callable]], str, Callable],
        candidate_classes: Optional[
            Union[Type[Candidate], List[Type[Candidate]]]
        ],
    ) -> Optional[Dict[str, set]]:
        """Map each key name to the table names of the targeted candidate classes.

        Shared by ``upsert_keys`` and ``drop_keys``.

        :param keys: A single key or an iterable of keys. A key is an LF (its
            ``__name__`` is used), an object with a ``name`` attribute, or a
            plain string.
        :param candidate_classes: A candidate class or a list/tuple of them.
            If falsy (None or empty), all candidate classes of this Labeler are
            used.
        :return: ``{key_name: {candidate_table_name, ...}}``, or None (after
            logging a warning) when none of the given candidate classes belong
            to this Labeler.
        """
        from collections import abc

        # A single key (string, LF, or other non-iterable object) is wrapped;
        # any other iterable (list, tuple, set, generator, ...) is materialized.
        # Previously only list/tuple were recognized, so e.g. a set of LFs was
        # treated as one (unhashable) key.
        if isinstance(keys, str) or not isinstance(keys, abc.Iterable):
            key_list = [keys]
        else:
            key_list = list(keys)

        if candidate_classes:
            if not isinstance(candidate_classes, (list, tuple)):
                candidate_classes = [candidate_classes]
            # Ensure only candidate classes associated with the labeler are used.
            table_names = [
                cls.__tablename__
                for cls in candidate_classes
                if cls in self.candidate_classes
            ]
            if not table_names:
                logger.warning(
                    "You didn't specify valid candidate classes for this Labeler."
                )
                return None
        else:
            # If unspecified, just use all candidate classes
            table_names = [cls.__tablename__ for cls in self.candidate_classes]

        key_map: Dict[str, set] = {}
        for key in key_list:
            if hasattr(key, "__name__"):
                name = key.__name__  # e.g. a plain-function LF
            elif hasattr(key, "name"):
                name = key.name  # object exposing a ``name`` attribute
            else:
                name = key  # plain string key
            # A fresh set per key so entries never share mutable state.
            key_map[name] = set(table_names)
        return key_map

    def upsert_keys(
        self,
        keys: Iterable[Union[str, Callable]],
        candidate_classes: Optional[
            Union[Type[Candidate], List[Type[Candidate]]]
        ] = None,
    ) -> None:
        """Upsert the specified keys from LabelKeys.

        :param keys: A list of labeling functions to upsert (a single LF or key
            name, or any iterable of them, is accepted).
        :param candidate_classes: A list of the Candidates to upsert the key for.
            If None, upsert the keys for all candidate classes associated with
            this Labeler.
        """
        key_map = self._build_key_map(keys, candidate_classes)
        if key_map is None:
            return
        upsert_keys(self.session, LabelKey, key_map)

    def drop_keys(
        self,
        keys: Iterable[Union[str, Callable]],
        candidate_classes: Optional[
            Union[Type[Candidate], List[Type[Candidate]]]
        ] = None,
    ) -> None:
        """Drop the specified keys from LabelKeys.

        :param keys: A list of labeling functions to delete (a single LF or key
            name, or any iterable of them, is accepted).
        :param candidate_classes: A list of the Candidates to drop the key for.
            If None, drops the keys for all candidate classes associated with
            this Labeler.
        """
        key_map = self._build_key_map(keys, candidate_classes)
        if key_map is None:
            return
        drop_keys(self.session, LabelKey, key_map)

    def _add(self, records_list: List[List[Dict[str, Any]]]) -> None:
        """Upsert the per-class record lists into ``self.table``.

        ``self.table`` is set by ``apply``/``update`` before processing starts.
        """
        for records in records_list:
            batch_upsert_records(self.session, self.table, records)

    def clear(  # type: ignore
        self,
        train: bool,
        split: int,
        lfs: Optional[List[List[Callable]]] = None,
        table: Table = Label,
        **kwargs: Any,
    ) -> None:
        """Delete Labels of each class from the database.

        Note that the split filter is applied to every ``Candidate`` row, not
        only to the candidate classes of this Labeler.

        :param train: Whether or not to clear the LabelKeys.
        :param split: Which split of candidates to clear labels from.
        :param lfs: This parameter is ignored.
        :param table: A (database) table labels are cleared from.
            Takes `Label` (by default) or `GoldLabel`.
        """
        # Clear Labels for the candidates in the split passed in.
        logger.info(f"Clearing Labels (split {split})")

        if split == ALL_SPLITS:
            sub_query = self.session.query(Candidate.id).subquery()
        else:
            sub_query = (
                self.session.query(Candidate.id)
                .filter(Candidate.split == split)
                .subquery()
            )
        query = self.session.query(table).filter(table.candidate_id.in_(sub_query))
        query.delete(synchronize_session="fetch")

        # Delete all old annotation keys
        if train:
            key_table = LabelKey if table == Label else GoldLabelKey
            logger.debug(
                f"Clearing all {key_table.__name__}s from {self.candidate_classes}..."
            )
            drop_all_keys(self.session, key_table, self.candidate_classes)

    def clear_all(self, table: Table = Label) -> None:
        """Delete all Labels.

        Also deletes every row of the matching key table (`LabelKey` for
        `Label`, `GoldLabelKey` otherwise).

        :param table: A (database) table labels are cleared from.
            Takes `Label` (by default) or `GoldLabel`.
        """
        key_table = LabelKey if table == Label else GoldLabelKey
        logger.info(f"Clearing ALL {table.__name__}s and {key_table.__name__}s.")
        self.session.query(table).delete(synchronize_session="fetch")
        self.session.query(key_table).delete(synchronize_session="fetch")

    def _after_apply(
        self, train: bool = False, table: Table = Label, **kwargs: Any
    ) -> None:
        """Rebuild the key table from the stored labels when ``train`` is set.

        Every key found on a label in ``table`` is mapped to the set of
        candidate table names it was used with; the key table is emptied and
        then repopulated from that map.
        """
        # Insert all Label Keys
        if train:
            key_map: DefaultDict[str, set] = defaultdict(set)
            for label in self.session.query(table).all():
                cand = label.candidate
                for key in label.keys:
                    key_map[key].add(cand.__class__.__tablename__)
            key_table = LabelKey if table == Label else GoldLabelKey
            self.session.query(key_table).delete(synchronize_session="fetch")
            # TODO: upsert is too much. insert is fine as all keys are deleted.
            upsert_keys(self.session, key_table, key_map)

    def get_gold_labels(
        self, cand_lists: List[List[Candidate]], annotator: Optional[str] = None
    ) -> List[np.ndarray]:
        """Load dense matrix of GoldLabels for each candidate_class.

        :param cand_lists: The candidates to get gold labels for.
        :param annotator: A specific annotator key to get labels for. Default
            None.
        :raises ValueError: If get_gold_labels is called before gold labels are
            loaded, the result will contain ABSTAIN values. We raise a
            ValueError to help indicate this potential mistake to the user.
        :return: A list of MxN dense matrix where M are the candidates and N is
            the annotators. If annotator is provided, return a list of Mx1
            matrix.
        """
        gold_labels = [
            unshift_label_matrix(m)
            for m in get_sparse_matrix(
                self.session, GoldLabelKey, cand_lists, key=annotator
            )
        ]

        for cand_labels in gold_labels:
            if ABSTAIN in cand_labels:
                raise ValueError(
                    "Gold labels contain ABSTAIN labels. "
                    "Did you load gold labels beforehand?"
                )

        return gold_labels

    def get_label_matrices(self, cand_lists: List[List[Candidate]]) -> List[np.ndarray]:
        """Load dense matrix of Labels for each candidate_class.

        :param cand_lists: The candidates to get labels for.
        :return: A list of MxN dense matrix where M are the candidates and N is
            the labeling functions.
        """
        return [
            unshift_label_matrix(m)
            for m in get_sparse_matrix(self.session, LabelKey, cand_lists)
        ]
class LabelerUDF(UDF):
    """UDF that applies labeling functions (LFs) to the candidates of a document."""

    def __init__(
        self,
        candidate_classes: Union[Type[Candidate], List[Type[Candidate]]],
        **kwargs: Any,
    ):
        """Initialize the LabelerUDF.

        :param candidate_classes: A candidate class or a list of them. The
            position of a class in this list selects its list of LFs in the
            ``lfs`` passed to ``apply``.
        """
        self.candidate_classes = (
            candidate_classes
            if isinstance(candidate_classes, (list, tuple))
            else [candidate_classes]
        )

        super().__init__(**kwargs)

    def _f_gen(self, c: Candidate) -> Iterator[Tuple[int, str, int]]:
        """Yield ``(candidate id, LF name, stored label)`` for each LF run on ``c``.

        LF outputs are converted to the shifted integers ``{0, 1, ..., k}``:

        - an integer (builtin ``int``, including ``bool``, or a numpy integer)
          is assumed to be already mapped correctly and is shifted by one;
        - ``None`` is a protected output meaning the LF abstains, stored as
          ``ABSTAIN + 1``;
        - a value found in ``c.values`` is stored as its index plus one.

        Uses ``self.lfs``, which ``apply`` sets before records are generated.

        :raises ValueError: If an LF output matches none of the cases above.
        """
        lf_idx = self.candidate_classes.index(c.__class__)
        cid = c.id
        # All LFs are evaluated before anything is yielded.
        outputs = [
            (
                lf.__name__ if hasattr(lf, "__name__") else lf.name,  # type: ignore
                lf(c),
            )
            for lf in self.lfs[lf_idx]
        ]
        for lf_key, label in outputs:
            # Note: We assume if the LF output is an int, it is already
            # mapped correctly. numpy integers are accepted too (they are not
            # ``int`` instances) and converted to a builtin int.
            if isinstance(label, (int, np.integer)):
                yield cid, lf_key, int(label) + 1  # convert to {0, 1, ..., k}
            # None is a protected LF output value corresponding to ABSTAIN,
            # representing LF abstaining
            elif label is None:
                yield cid, lf_key, ABSTAIN + 1  # convert to {0, 1, ..., k}
            elif label in c.values:
                # convert to {0, 1, ..., k}
                yield cid, lf_key, c.values.index(label) + 1
            else:
                raise ValueError(
                    f"Can't parse label value {label} for candidate values {c.values}"
                )

    def apply(  # type: ignore
        self,
        doc: Document,
        lfs: List[List[Callable]],
        table: Table = Label,
        **kwargs: Any,
    ) -> List[List[Dict[str, Any]]]:
        """Label all candidates of the given document.

        :param doc: A document to process.
        :param lfs: One list of labeling functions per candidate class, in the
            same order as ``candidate_classes``.
        :param table: The (database) table passed to ``get_mapping`` when
            building the records; `Label` (by default) or `GoldLabel`.
        :raises ValueError: If lfs is None.
        :return: One list of label records per candidate class.
        """
        logger.debug(f"Document: {doc}")

        if lfs is None:
            raise ValueError("Must provide lfs kwarg.")

        self.lfs = lfs

        # Get all the candidates in this doc that will be labeled.
        # NOTE(review): assumes Document exposes one "<tablename>s" attribute
        # per candidate class — confirm against the candidate model setup.
        cands_list = [
            getattr(doc, candidate_class.__tablename__ + "s")
            for candidate_class in self.candidate_classes
        ]

        records_list = [
            list(get_mapping(table, cands, self._f_gen)) for cands in cands_list
        ]
        return records_list