
✂️ Snorkel Intro Tutorial: Data Slicing

In real-world applications, some model outcomes are often more important than others — e.g. vulnerable cyclist detections in an autonomous driving task, or, in our running spam application, potentially malicious link redirects to external websites.

Traditional machine learning systems optimize for overall quality, which may be too coarse-grained. Models that achieve high overall performance might produce unacceptable failure rates on critical slices of the data: data subsets such as the cyclist detections or potentially malicious links described above.

In this tutorial, we:

  1. Introduce Slicing Functions (SFs) as a programming interface
  2. Monitor application-critical data subsets
  3. Improve model performance on slices

Note: this tutorial differs from the labeling tutorial in that we use ground truth labels in the train split for demo purposes. SFs are intended to be used after the training set has already been labeled by labeling functions (LFs) or by hand in the training data pipeline.

from utils import load_spam_dataset

df_train, df_test = load_spam_dataset(load_train_labels=True)

1. Write slicing functions

We leverage slicing functions (SFs), which output binary masks indicating whether a data point is in the slice or not. Each slice represents some noisily defined subset of the data (corresponding to an SF) that we’d like to programmatically monitor.

In the following cells, we use the @slicing_function() decorator to initialize an SF that identifies short comments.

You’ll notice that the short_comment SF is a heuristic, like the other programmatic operators we’ve defined, and may not fully cover the slice of interest. That’s okay: in the last section, we’ll show how a model can handle this in Snorkel.

import re
from snorkel.slicing import slicing_function


@slicing_function()
def short_comment(x):
    """Ham comments are often short, such as 'cool video!'"""
    return len(x.text.split()) < 5


sfs = [short_comment]
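Conceptually, an SF is just a function from a data point to a boolean, and applying it across a dataset produces the binary mask described above. The sketch below reproduces that idea in plain Python without Snorkel; types.SimpleNamespace is a hypothetical stand-in for a DataFrame row that exposes a .text attribute, and the comments are made up.

```python
from types import SimpleNamespace


def short_comment(x):
    """Plain-Python stand-in for the short_comment SF: True for < 5 tokens."""
    return len(x.text.split()) < 5


# Hypothetical comments, each exposing a .text attribute like a DataFrame row.
comments = [
    SimpleNamespace(text="cool video!"),
    SimpleNamespace(text="check out my channel for free gift cards"),
    SimpleNamespace(text="I like shakira.."),
]

# Applying the SF to every data point yields one 0/1 entry per point.
mask = [int(short_comment(x)) for x in comments]
print(mask)  # [1, 0, 1]
```

In Snorkel, the decorator and the SF appliers take care of this iteration and of collecting the results.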

Visualize slices

With a utility function, slice_dataframe, we can visualize data points belonging to this slice in a pandas.DataFrame.

from snorkel.slicing import slice_dataframe

short_comment_df = slice_dataframe(df_test, short_comment)
short_comment_df[["text", "label"]].head()
     text                  label
194  super music           0
2    I like shakira..      0
110  subscribe to my feed  1
263  Awesome               0
77   Nice                  0
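slice_dataframe behaves like a row filter driven by the SF. As a rough plain-Python sketch (slice_rows and the rows list are hypothetical, not Snorkel code), filtering and then counting labels also shows that a slice need not be label-pure: in the output above, "subscribe to my feed" is a short comment labeled spam.

```python
def short_comment(row):
    return len(row["text"].split()) < 5


def slice_rows(rows, sf):
    """Keep only the rows for which the slicing function returns True."""
    return [row for row in rows if sf(row)]


# Hypothetical rows in the same (text, label) shape as df_test; 1 = spam, 0 = ham.
rows = [
    {"text": "super music", "label": 0},
    {"text": "subscribe to my feed", "label": 1},
    {"text": "Check out my new video, link in my channel description", "label": 1},
    {"text": "Nice", "label": 0},
]

in_slice = slice_rows(rows, short_comment)
n_spam = sum(row["label"] for row in in_slice)
print(len(in_slice), n_spam)  # 3 1
```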

2. Monitor slice performance with Scorer.score_slices

In this section, we’ll demonstrate how we might monitor slice performance on the short_comment slice — this approach is compatible with any modeling framework.

Train a simple classifier

First, we featurize the data — as you saw in the introductory Spam tutorial, we can extract simple bag-of-words features and store them as numpy arrays.

from sklearn.feature_extraction.text import CountVectorizer
from utils import df_to_features

vectorizer = CountVectorizer(ngram_range=(1, 1))
X_train, Y_train = df_to_features(vectorizer, df_train, "train")
X_test, Y_test = df_to_features(vectorizer, df_test, "test")

We define a LogisticRegression model from sklearn.

from sklearn.linear_model import LogisticRegression

sklearn_model = LogisticRegression(C=0.001, solver="liblinear")
sklearn_model.fit(X=X_train, y=Y_train)
LogisticRegression(C=0.001, class_weight=None, dual=False, fit_intercept=True,
                   intercept_scaling=1, l1_ratio=None, max_iter=100,
                   multi_class='auto', n_jobs=None, penalty='l2',
                   random_state=None, solver='liblinear', tol=0.0001, verbose=0,
                   warm_start=False)
from snorkel.utils import preds_to_probs

preds_test = sklearn_model.predict(X_test)
probs_test = preds_to_probs(preds_test, 2)
from sklearn.metrics import f1_score

print(f"Test set F1: {100 * f1_score(Y_test, preds_test):.1f}%")
Test set F1: 92.5%
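Because LogisticRegression.predict returns hard labels, preds_to_probs turns them into per-class probability rows that can be passed to the Scorer. A minimal NumPy sketch of that conversion, assuming the one-hot behavior (one_hot_probs is a hypothetical name, not Snorkel's code):

```python
import numpy as np


def one_hot_probs(preds, num_classes):
    """Map hard class predictions to one-hot rows of probabilities."""
    preds = np.asarray(preds, dtype=int)
    # Row i of the identity matrix is the one-hot vector for class i.
    return np.eye(num_classes)[preds]


probs = one_hot_probs([0, 1, 1], 2)
print(probs)
# [[1. 0.]
#  [0. 1.]
#  [0. 1.]]
```

If calibrated soft scores are wanted instead, sklearn_model.predict_proba(X_test) returns the model's own class probabilities.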

Store slice metadata in S

We apply our list of sfs to the data using an SF applier. For our data format, we leverage the PandasSFApplier. The output of the applier is an np.recarray which stores vectors in named fields indicating whether each of $n$ data points belongs to the corresponding slice.

from snorkel.slicing import PandasSFApplier

applier = PandasSFApplier(sfs)
S_test = applier.apply(df_test)
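To make the structure of S concrete, here is a small NumPy sketch that builds the same kind of record array by hand with np.rec.fromarrays: one named field per SF, one 0/1 entry per data point. The SFs are plain functions over strings rather than Snorkel SFs, and this is not how PandasSFApplier is implemented internally; it only illustrates the output shape.

```python
import numpy as np

texts = ["cool video!", "please subscribe plz", "Check out bit.ly/xyz for free stuff now"]


def short_comment(text):
    return len(text.split()) < 5


def keyword_please(text):
    return any(word in text.lower() for word in ["please", "plz"])


sfs = [short_comment, keyword_please]

# One named field per SF, one 0/1 entry per data point.
S = np.rec.fromarrays(
    [np.array([int(sf(t)) for t in texts]) for sf in sfs],
    names=[sf.__name__ for sf in sfs],
)

print(S.dtype.names)       # ('short_comment', 'keyword_please')
print(S["short_comment"])  # [1 1 0]
print(S.keyword_please)    # [0 1 0]
```

Both field indexing (S["short_comment"]) and attribute access (S.keyword_please) work on a recarray.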

Now, we initialize a Scorer using the desired metrics.

from snorkel.analysis import Scorer

scorer = Scorer(metrics=["f1"])

Using the score_slices method, we can see both overall and slice-specific performance.

scorer.score_slices(
    S=S_test, golds=Y_test, preds=preds_test, probs=probs_test, as_dataframe=True
)
f1
overall 0.925000
short_comment 0.666667

Despite high overall performance, the short_comment slice performs poorly here!
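The idea behind score_slices can be sketched in a few lines: compute the metric once on all points, then again on only the points whose slice mask is 1. This is a plain-Python illustration with hypothetical helper names (f1, score_by_slice) and toy data, not the Scorer's implementation; it uses the identity F1 = 2TP / (2TP + FP + FN).

```python
def f1(golds, preds, pos=1):
    """Binary F1 for the positive class: 2TP / (2TP + FP + FN)."""
    tp = sum(g == pos and p == pos for g, p in zip(golds, preds))
    fp = sum(g != pos and p == pos for g, p in zip(golds, preds))
    fn = sum(g == pos and p != pos for g, p in zip(golds, preds))
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def score_by_slice(golds, preds, slices):
    """Overall F1 plus F1 restricted to the points in each slice mask."""
    scores = {"overall": f1(golds, preds)}
    for name, mask in slices.items():
        idx = [i for i, m in enumerate(mask) if m]
        scores[name] = f1([golds[i] for i in idx], [preds[i] for i in idx])
    return scores


# Toy data: both mistakes happen to fall inside the slice.
golds = [1, 1, 1, 1, 0, 0, 1, 0]
preds = [1, 1, 1, 1, 0, 0, 0, 1]
slices = {"short_comment": [0, 0, 0, 0, 0, 0, 1, 1]}

scores = score_by_slice(golds, preds, slices)
print(scores)  # {'overall': 0.8, 'short_comment': 0.0}
```

When errors concentrate in a small slice, the overall number barely reflects them, which is the pattern the short_comment result shows.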

Write additional slicing functions (SFs)

Slices are dynamic — as monitoring needs grow or change with new data distributions or application needs, an ML pipeline might require dozens, or even hundreds, of slices.

We’ll take inspiration from the labeling tutorial to write additional slicing functions. We demonstrate how the same powerful preprocessors and utilities available for labeling functions can be leveraged for slicing functions.

from snorkel.slicing import SlicingFunction, slicing_function
from snorkel.preprocess import preprocessor


# Keyword-based SFs
def keyword_lookup(x, keywords):
    return any(word in x.text.lower() for word in keywords)


def make_keyword_sf(keywords):
    return SlicingFunction(
        name=f"keyword_{keywords[0]}",
        f=keyword_lookup,
        resources=dict(keywords=keywords),
    )


keyword_please = make_keyword_sf(keywords=["please", "plz"])


# Regex-based SFs
@slicing_function()
def regex_check_out(x):
    return bool(re.search(r"check.*out", x.text, flags=re.I))


@slicing_function()
def short_link(x):
    """Returns whether text matches common pattern for shortened ".ly" links."""
    return bool(re.search(r"\w+\.ly", x.text))


# Leverage preprocessor in SF
from textblob import TextBlob


@preprocessor(memoize=True)
def textblob_sentiment(x):
    scores = TextBlob(x.text)
    x.polarity = scores.sentiment.polarity
    return x


@slicing_function(pre=[textblob_sentiment])
def textblob_polarity(x):
    return x.polarity > 0.9
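The resources argument in make_keyword_sf binds extra keyword arguments to the SF's function, so a single lookup function can back many SFs. Roughly the same binding can be sketched with functools.partial (an analogy, not how SlicingFunction works internally). The sketch also shows a side effect of the substring check in keyword_lookup: "please" matches "pleased". The whole-word variant is a hypothetical alternative, not part of the tutorial.

```python
import re
from functools import partial
from types import SimpleNamespace


def keyword_lookup(x, keywords):
    # Substring match, as in the keyword_lookup above.
    return any(word in x.text.lower() for word in keywords)


def keyword_lookup_whole_word(x, keywords):
    # Hypothetical variant that only matches whole words.
    tokens = re.findall(r"\w+", x.text.lower())
    return any(word in tokens for word in keywords)


# Binding the keyword list up front mirrors resources=dict(keywords=...).
please_substring = partial(keyword_lookup, keywords=["please", "plz"])
please_whole_word = partial(keyword_lookup_whole_word, keywords=["please", "plz"])

x = SimpleNamespace(text="I am so pleased with this song")
print(please_substring(x), please_whole_word(x))  # True False
```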

Again, we’d like to visualize data points in a particular slice. This time, we’ll inspect the textblob_polarity slice.

Most data points with high-polarity sentiments are strong opinions about the video; hence, they are usually relevant to the video, and the corresponding labels are $0$ (not spam). We might define this slice for product and marketing reasons: it’s important to make sure that we don’t misclassify very positive comments from good users.

polarity_df = slice_dataframe(df_test, textblob_polarity)
polarity_df[["text", "label"]].head()
     text                                  label
263  Awesome                               0
240  Shakira is the best dancer            0
261  OMG LISTEN TO THIS ITS SOO GOOD!! :D  0
14   Shakira is very beautiful             0
114  awesome                               0
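The textblob_sentiment preprocessor is declared with memoize=True, so its output can be cached and reused instead of recomputed each time an SF needs it. The caching idea can be sketched with functools.lru_cache and a deliberately crude, hypothetical polarity scorer; toy_polarity is not TextBlob, and lru_cache is an analogy rather than Snorkel's memoization mechanism.

```python
from functools import lru_cache

calls = {"n": 0}


@lru_cache(maxsize=None)
def toy_polarity(text):
    """Hypothetical, crude stand-in for TextBlob polarity (fraction of positive words)."""
    calls["n"] += 1
    positive = {"awesome", "best", "good", "beautiful"}
    words = text.lower().split()
    return sum(w in positive for w in words) / max(len(words), 1)


# Two SFs that both need the same preprocessed value.
def very_positive(text):
    return toy_polarity(text) > 0.9


def somewhat_positive(text):
    return toy_polarity(text) > 0.3


texts = ["awesome", "Shakira is the best dancer"]
results = [(very_positive(t), somewhat_positive(t)) for t in texts]
# Each text was scored only once, even though two SFs used it.
print(results, calls["n"])  # [(True, True), (False, False)] 2
```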

We can evaluate performance on all SFs using the model-agnostic Scorer.

extra_sfs = [keyword_please, regex_check_out, short_link, textblob_polarity]

sfs = [short_comment] + extra_sfs
slice_names = [sf.name for sf in sfs]

Let’s see how the sklearn model we trained earlier performs on these new slices.

applier = PandasSFApplier(sfs)
S_test = applier.apply(df_test)
scorer.score_slices(
    S=S_test, golds=Y_test, preds=preds_test, probs=probs_test, as_dataframe=True
)
f1
overall 0.925000
short_comment 0.666667
keyword_please 1.000000
regex_check_out 1.000000
short_link 0.500000
textblob_polarity 0.727273

Some slices, such as keyword_please and regex_check_out, score perfectly on our small test set, while others, such as short_link and textblob_polarity, fall well below the overall score. At the very least, we may want to monitor these to make sure that as we iterate to improve certain slices like short_comment, we don’t hurt the performance of others. Next, we’ll introduce a model that helps us to do this balancing act automatically.
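One lightweight way to turn these scores into a monitoring check is to flag any slice whose F1 falls more than some tolerance below the overall score. The sketch hard-codes the numbers from the table above; the MAX_GAP threshold of 0.15 is an arbitrary assumption, not a Snorkel setting.

```python
overall = 0.925
slice_f1 = {
    "short_comment": 0.666667,
    "keyword_please": 1.000000,
    "regex_check_out": 1.000000,
    "short_link": 0.500000,
    "textblob_polarity": 0.727273,
}

MAX_GAP = 0.15  # hypothetical tolerance below the overall score

# Flag slices that lag the overall score by more than MAX_GAP, worst first.
flagged = sorted(
    (name for name, score in slice_f1.items() if overall - score > MAX_GAP),
    key=slice_f1.get,
)
print(flagged)  # ['short_link', 'short_comment', 'textblob_polarity']
```

Re-running a check like this after each model change is one way to catch regressions on slices that were previously fine.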

3. Improve slice performance

In the following section, we demonstrate a modeling approach that we call Slice-based Learning, which improves performance by adding extra slice-specific representational capacity to whichever model we’re using. Intuitively, we’d like the model to learn representations that are better suited to handle data points in each slice. In our approach, we model each slice as a separate “expert task” in the style of multi-task learning; for further details of how slice-based learning works under the hood, check out the code (with paper coming soon)!

In other approaches, one might attempt to increase slice performance with techniques like oversampling (e.g. with PyTorch’s WeightedRandomSampler), effectively shifting the training distribution towards certain populations.

This might work with a small number of slices, but with hundreds or thousands of production slices at scale, it could quickly become intractable to tune upsampling weights per slice.
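To see why per-slice tuning gets unwieldy, here is a sketch of the oversampling approach: every slice needs its own hand-picked factor, and points that belong to several slices need a rule for combining factors (taking the max is an arbitrary choice here). random.choices stands in for PyTorch's WeightedRandomSampler, which likewise draws indices in proportion to per-example weights; the memberships and factors are hypothetical.

```python
import random
from collections import Counter

# Hypothetical slice memberships for 6 training points (1 = in slice).
S = {
    "short_comment": [1, 0, 0, 1, 0, 0],
    "short_link": [0, 0, 1, 0, 0, 0],
}
# One hand-tuned upsampling factor per slice: the knobs that stop scaling.
slice_weight = {"short_comment": 2.0, "short_link": 4.0}


def example_weights(S, slice_weight, n):
    """Weight each point by the largest factor among the slices it belongs to."""
    weights = [1.0] * n
    for name, mask in S.items():
        for i, in_slice in enumerate(mask):
            if in_slice:
                weights[i] = max(weights[i], slice_weight[name])
    return weights


weights = example_weights(S, slice_weight, n=6)
print(weights)  # [2.0, 1.0, 4.0, 2.0, 1.0, 1.0]

# Draw an oversampled "epoch" of indices, with replacement.
random.seed(0)
epoch = random.choices(range(6), weights=weights, k=6000)
print(Counter(epoch).most_common(1))
```

With hundreds of slices, the slice_weight table grows accordingly, and every entry interacts with the others through overlapping memberships.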

Constructing a SliceAwareClassifier

To cope with scale, we will attempt to learn and combine many slice-specific representations with an attention mechanism. (See Section 3 of our technical report for details on this approach.)
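The general shape of attention-based combination can be sketched with NumPy: each slice expert produces its own representation of a data point, a softmax over per-slice scores turns those scores into weights, and the combined representation is the weighted sum. This is a generic illustration with made-up numbers; how SliceAwareClassifier computes its attention scores is described in the technical report and is not reproduced here.

```python
import numpy as np


def softmax(z):
    # Subtract the max for numerical stability before exponentiating.
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()


# Hypothetical slice-specific representations for ONE data point:
# 3 slice "experts" (rows), each producing a 4-dim feature vector.
slice_reps = np.array([
    [0.2, 0.1, 0.0, 0.5],
    [0.9, 0.4, 0.3, 0.1],
    [0.0, 0.0, 0.8, 0.2],
])
# Hypothetical per-slice relevance scores for this data point.
scores = np.array([0.5, 3.0, -1.0])

attn = softmax(scores)        # weights that sum to 1
combined = attn @ slice_reps  # attention-weighted sum of representations
print(attn.round(3), combined.shape)
```

The combined vector has the same dimension as each slice representation, so a single prediction head can sit on top of it regardless of how many slices there are.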

First we’ll initialize a SliceAwareClassifier:

  • base_architecture: We define a simple Multi-Layer Perceptron (MLP) in PyTorch to serve as the primary representation architecture. Note that the SliceAwareClassifier is agnostic to the base architecture: you might leverage a Transformer model for text, or a ResNet for images.
  • head_dim: identifies the final output feature dimension of the base_architecture.
  • slice_names: the slices that we plan to train on with this classifier.
  • scorer: the Scorer defined earlier, used to compute the metrics reported by score_slices.

from snorkel.slicing import SliceAwareClassifier
from utils import get_pytorch_mlp

# Define model architecture
bow_dim = X_train.shape[1]
hidden_dim = bow_dim
mlp = get_pytorch_mlp(hidden_dim=hidden_dim, num_layers=2)

# Initialize slice model
slice_model = SliceAwareClassifier(
    base_architecture=mlp,
    head_dim=hidden_dim,
    slice_names=[sf.name for sf in sfs],
    scorer=scorer,
)

Next, we’ll generate the S matrices for the train and test splits with the new set of slicing functions.

applier = PandasSFApplier(sfs)
S_train = applier.apply(df_train)
S_test = applier.apply(df_test)

In order to train using slice information, we’d like to initialize a slice-aware dataloader. To do this, we can use slice_model.make_slice_dataloader to add slice labels to an existing dataloader.

Under the hood, this method leverages slice metadata to add slice labels to the appropriate fields such that it will be compatible with our model, a SliceAwareClassifier.

from utils import create_dict_dataloader

BATCH_SIZE = 64

train_dl = create_dict_dataloader(X_train, Y_train, "train")
train_dl_slice = slice_model.make_slice_dataloader(
    train_dl.dataset, S_train, shuffle=True, batch_size=BATCH_SIZE
)
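# Note: the test set is registered under the split name "train" here, which is
# why the split column in the score_slices output below reads "train".
test_dl = create_dict_dataloader(X_test, Y_test, "train")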
test_dl_slice = slice_model.make_slice_dataloader(
    test_dl.dataset, S_test, shuffle=False, batch_size=BATCH_SIZE
)
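What "adding slice labels" means can be sketched in plain Python: for each slice, derive one label field saying whether a point is in the slice, and one holding the gold label only for in-slice points. The _pred naming follows the score_slices output below; the _ind indicator field name, the -1 "ignore" marker, and the helper sketch_slice_labels are assumptions for illustration, not Snorkel's exact encoding.

```python
IGNORE = -1  # assumed "no label" marker for points outside a slice


def sketch_slice_labels(Y, S, task="task"):
    """Sketch: derive per-slice label fields from gold labels Y and slice masks S."""
    labels = {task: list(Y)}
    for name, mask in S.items():
        # Indicator head: is this point in the slice?
        labels[f"{task}_slice:{name}_ind"] = list(mask)
        # Prediction head: gold label inside the slice, ignored outside it.
        labels[f"{task}_slice:{name}_pred"] = [
            y if m else IGNORE for y, m in zip(Y, mask)
        ]
    return labels


Y = [1, 0, 0, 1]
S = {"short_comment": [0, 1, 1, 0]}
labels = sketch_slice_labels(Y, S)
for key, vals in labels.items():
    print(key, vals)
# task [1, 0, 0, 1]
# task_slice:short_comment_ind [0, 1, 1, 0]
# task_slice:short_comment_pred [-1, 0, 0, -1]
```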

Representation learning with slices

Using Snorkel’s Trainer, we fit our classifier with the training set dataloader.

from snorkel.classification import Trainer

# For demonstration purposes, we set n_epochs=2
trainer = Trainer(n_epochs=2, lr=1e-4, progress_bar=True)
trainer.fit(slice_model, [train_dl_slice])
Epoch 0:: 100%|██████████| 25/25 [00:24<00:00,  1.01it/s, model/all/train/loss=0.5, model/all/train/lr=0.0001]
Epoch 1:: 100%|██████████| 25/25 [00:25<00:00,  1.01s/it, model/all/train/loss=0.257, model/all/train/lr=0.0001]

At inference time, the primary task head (task) makes all final predictions. We’d like to evaluate all the slice heads on the original task head: score_slices remaps all slice-related labels, denoted task_slice:{slice_name}_pred, to be evaluated on the primary task.

slice_model.score_slices([test_dl_slice], as_dataframe=True)
label dataset split metric score
0 task SnorkelDataset train f1 0.941704
1 task_slice:short_comment_pred SnorkelDataset train f1 0.769231
2 task_slice:keyword_please_pred SnorkelDataset train f1 0.977778
3 task_slice:regex_check_out_pred SnorkelDataset train f1 1.000000
4 task_slice:short_link_pred SnorkelDataset train f1 0.500000
5 task_slice:textblob_polarity_pred SnorkelDataset train f1 0.800000
6 task_slice:base_pred SnorkelDataset train f1 0.941704

Note: in this toy dataset, we see high variance in slice performance, because our dataset is so small that (i) there are few data points in the train split, giving little signal to learn over, and (ii) there are few data points in the test split, making our evaluation metrics very noisy. For a demonstration of data slicing deployed in state-of-the-art models, please see our SuperGLUE tutorials.
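A quick calculation shows why small slices make the metrics noisy: one misclassified point moves F1 far more on a handful of examples than on hundreds. The counts below are hypothetical and chosen only for illustration.

```python
def f1(golds, preds):
    """Binary F1 for class 1: 2TP / (2TP + FP + FN)."""
    tp = sum(g == 1 and p == 1 for g, p in zip(golds, preds))
    fp = sum(g == 0 and p == 1 for g, p in zip(golds, preds))
    fn = sum(g == 1 and p == 0 for g, p in zip(golds, preds))
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


# A tiny slice: 3 spam and 3 ham comments, all predicted correctly...
small_golds = [1, 1, 1, 0, 0, 0]
small_preds = list(small_golds)
# ...then a single missed spam comment.
small_flip = [0, 1, 1, 0, 0, 0]

# A larger set with the same 50/50 mix (300 spam, 300 ham), same single mistake.
big_golds = [1] * 300 + [0] * 300
big_preds = list(big_golds)
big_flip = [0] + [1] * 299 + [0] * 300

print(f1(small_golds, small_preds), round(f1(small_golds, small_flip), 3))  # 1.0 0.8
print(f1(big_golds, big_preds), round(f1(big_golds, big_flip), 3))  # 1.0 0.998
```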


Recap

This tutorial walked through the process of authoring slices, monitoring model performance on specific slices, and improving model performance using slice information. This programming abstraction provides a mechanism to heuristically identify critical data subsets. For more technical details about Slice-based Learning, please see our NeurIPS 2019 paper!