# Source code for fonduer.learning.utils

```python
import logging
from typing import Set, Tuple

import numpy as np
from torch.utils.data import Dataset

from fonduer.learning.models.marginal import Marginal

logger = logging.getLogger(__name__)

def save_marginals(session, X, marginals, training=True):
    """Save marginal probabilities for a set of Candidates to db.

    All existing ``Marginal`` rows whose ``training`` flag equals ``training``
    are deleted, then one row is inserted for every (candidate, class) pair
    whose probability is strictly positive, and the session is committed.

    :param session: The database session used to issue the delete, the bulk
        insert and the final commit (presumably a SQLAlchemy session --
        confirm against the caller).
    :param X: A list of arbitrary objects with candidate ids accessible via a
        .id attrib
    :param marginals: A dense M x K matrix of marginal probabilities, where
        K is the cardinality of the candidates, OR a M-dim list/array if K=2.
    :param training: If True, these are training marginals / labels; else they
        are saved as end model predictions.

    Note: The marginals for k=0 are not stored, only for k = 1,...,K
    """
    # Make sure that we are working with a numpy array
    try:
        shape = marginals.shape
    except AttributeError:
        marginals = np.array(marginals)
        shape = marginals.shape

    # Handle binary input as M x 1-dim array; assume elements represent
    # positive (k=1) class values
    if len(shape) == 1:
        marginals = np.vstack([1 - marginals, marginals]).T
        n_classes = 2
    else:
        n_classes = shape[1]
    n_rows = shape[0]

    # Prepare values up front, only for classes k=1,...,K, so that a bad
    # index or a missing ``.id`` surfaces before anything is deleted.
    insert_vals = []
    for i in range(n_rows):
        for k in range(1, n_classes):
            p = marginals[i, k]
            if p > 0:
                insert_vals.append(
                    {
                        "candidate_id": X[i].id,
                        "training": training,
                        "value": k,
                        # We cast p in case it's a numpy type, which psycopg2
                        # does not handle
                        "probability": float(p),
                    }
                )

    # NOTE: This will delete all existing marginals of type `training`
    session.query(Marginal).filter(Marginal.training == training).delete(
        synchronize_session="fetch"
    )

    # Execute the bulk INSERT. Skipped when there is nothing to insert.
    # NOTE(review): an empty parameter list is presumably not treated as a
    # bulk insert by the driver -- confirm for the SQLAlchemy version in use.
    if insert_vals:
        session.execute(Marginal.__table__.insert(), insert_vals)
    session.commit()
    logger.info(f"Saved {len(marginals)} marginals")

class MultiModalDataset(Dataset):
    """A dataset holding all multimodal features in X and the corresponding labels Y.

    :param X: The features; must support ``len()`` and indexing.
    :param Y: Optional labels, indexed with the same index as ``X``. When it
        is None, items are returned without a label.
    """

    def __init__(self, X, Y=None):
        self.X = X
        self.Y = Y

    def __len__(self):
        """Return the number of items, taken from ``X``."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``X[idx]`` alone, or the pair ``(X[idx], Y[idx])`` when labels exist."""
        if self.Y is None:
            return self.X[idx]
        return self.X[idx], self.Y[idx]

def confusion_matrix(pred: Set, gold: Set) -> Tuple[Set, Set, Set]:
    """Return a confusion matrix.

    This can be used for both entity-level and mention-level evaluation.
    Arguments that are not already sets are converted with ``set()``, so any
    iterable of hashable items is accepted.

    :param pred: a set of predicted entities/candidates
    :type pred: set
    :param gold: a set of golden entities/candidates
    :type gold: set
    :return: a tuple of TP, FP, and FN
    :rtype: (set, set, set)
    """
    if not isinstance(pred, set):
        pred = set(pred)
    if not isinstance(gold, set):
        gold = set(gold)

    true_positives = pred & gold
    false_positives = pred - gold
    false_negatives = gold - pred
    return (true_positives, false_positives, false_negatives)
```