✂️ Snorkel Intro Tutorial: Data Slicing

In real-world applications, some model outcomes are often more important than others — e.g. vulnerable cyclist detections in an autonomous driving task, or, in our running spam application, potentially malicious link redirects to external websites.

Traditional machine learning systems optimize for overall quality, which may be too coarse-grained. Models that achieve high overall performance might still produce unacceptable failure rates on critical slices of the data, i.e. the data subsets that correspond to important outcomes like the ones above.

In this tutorial, we:

  1. Introduce Slicing Functions (SFs) as a programming interface
  2. Monitor application-critical data subsets
  3. Improve model performance on slices

Note: this tutorial differs from the labeling tutorial in that we use ground truth labels in the train split for demo purposes. SFs are intended to be used after the training set has already been labeled by LFs (or by hand) in the training data pipeline.

from utils import load_spam_dataset

df_train, df_valid, df_test = load_spam_dataset(load_train_labels=True, split_dev=False)

1. Write slicing functions

We leverage slicing functions (SFs), which output binary masks indicating whether a data point is in the slice or not. Each slice represents some noisily-defined subset of the data (corresponding to an SF) that we’d like to programmatically monitor.

In the following cells, we use the @slicing_function() decorator to initialize an SF that identifies shortened links in the spam dataset. These links could redirect us to potentially dangerous websites, and we don’t want our users to click them! To select the subset of shortened links in our dataset, we write a regex that checks for the commonly-used .ly extension.

You’ll notice that the short_link SF is a heuristic, like the other programmatic ops we’ve defined, and may not fully cover the slice of interest. That’s okay: in the last section, we’ll show how a model can handle this in Snorkel.

import re
from snorkel.slicing import slicing_function


@slicing_function()
def short_link(x):
    """Returns whether text matches common pattern for shortened ".ly" links."""
    return bool(re.search(r"\w+\.ly", x.text))


sfs = [short_link]
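
To make the "heuristic" caveat concrete, here is a minimal sketch that applies the same regular expression to a handful of strings using only the standard library. The sample comments are hypothetical and invented for illustration; they are not drawn from the spam dataset.

```python
import re

# Same pattern as the short_link SF above.
SHORT_LINK_PATTERN = re.compile(r"\w+\.ly")

# Hypothetical comments, invented for illustration (not from the dataset).
samples = [
    "adf.ly / KlD3Y",                      # shortened link, matched
    "free gift at http://bit.ly/abc123",   # shortened link, matched
    "watch at http://tinyurl.com/xyz",     # shortened link the pattern misses
    "great song, see really.lyrics.com",   # not a shortened link, still matched
    "Love this song!",                     # no link at all
]

matches = [bool(SHORT_LINK_PATTERN.search(text)) for text in samples]

for text, hit in zip(samples, matches):
    print(f"{hit!s:>5}  {text}")
```

The third and fourth samples show the two ways such a heuristic can be noisy: it can miss members of the intended slice, and it can include points that do not belong to it.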

Visualize slices

With a utility function, slice_dataframe, we can visualize data points belonging to this slice in a pandas.DataFrame.

from snorkel.slicing import slice_dataframe

short_link_df = slice_dataframe(df_valid, short_link)
short_link_df[["text", "label"]]
     text                                               label
280  Being paid to respond to fast paid surveys fro...  1
192  Meet The Richest Online Marketer NOW CLICK : ...   1
301  coby this USL and past :<br /><a href="http://...  1
350  adf.ly / KlD3Y                                     1
 18  Earn money for being online with 0 efforts! ...    1

2. Monitor slice performance with Scorer.score_slices

In this section, we’ll demonstrate how we might monitor slice performance on the short_link slice — this approach is compatible with any modeling framework.

Train a simple classifier

First, we featurize the data — as you saw in the introductory Spam tutorial, we can extract simple bag-of-words features and store them as numpy arrays.

from sklearn.feature_extraction.text import CountVectorizer
from utils import df_to_features

vectorizer = CountVectorizer(ngram_range=(1, 1))
X_train, Y_train = df_to_features(vectorizer, df_train, "train")
X_valid, Y_valid = df_to_features(vectorizer, df_valid, "valid")
X_test, Y_test = df_to_features(vectorizer, df_test, "test")

We define a LogisticRegression model from sklearn and show how we might visualize these slice-specific scores.

from sklearn.linear_model import LogisticRegression

sklearn_model = LogisticRegression(C=0.001, solver="liblinear")
sklearn_model.fit(X=X_train, y=Y_train)
print(f"Test set accuracy: {100 * sklearn_model.score(X_test, Y_test):.1f}%")
Test set accuracy: 92.8%

from snorkel.utils import preds_to_probs

preds_test = sklearn_model.predict(X_test)
probs_test = preds_to_probs(preds_test, 2)

Store slice metadata in S

We apply our list of sfs to the data using an SF applier. For our data format, we leverage the PandasSFApplier. The output of the applier is an np.recarray which stores vectors in named fields indicating whether each of $n$ data points belongs to the corresponding slice.

from snorkel.slicing import PandasSFApplier

applier = PandasSFApplier(sfs)
S_test = applier.apply(df_test)

Now, we initialize a Scorer using the desired metrics.

from snorkel.analysis import Scorer

scorer = Scorer(metrics=["accuracy", "f1"])

Using the score_slices method, we can see both overall and slice-specific performance.

scorer.score_slices(
    S=S_test, golds=Y_test, preds=preds_test, probs=probs_test, as_dataframe=True
)
            accuracy  f1
overall     0.928000  0.925
short_link  0.333333  0.500

Despite high overall performance, the short_link slice performs poorly here!
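
As a rough illustration of what slice-specific scoring computes, the sketch below restricts an accuracy calculation to the points selected by each slice mask. It is a plain-Python stand-in under stated assumptions, not the Scorer implementation: the helper names accuracy and score_by_slice, and the toy labels, are hypothetical.

```python
def accuracy(golds, preds):
    """Fraction of positions where the prediction equals the gold label."""
    correct = sum(int(g == p) for g, p in zip(golds, preds))
    return correct / len(golds)


def score_by_slice(slice_masks, golds, preds):
    """Accuracy overall and restricted to the members of each slice."""
    scores = {"overall": accuracy(golds, preds)}
    for name, mask in slice_masks.items():
        members = [i for i, in_slice in enumerate(mask) if in_slice]
        if not members:
            continue  # accuracy is undefined on an empty slice
        scores[name] = accuracy(
            [golds[i] for i in members], [preds[i] for i in members]
        )
    return scores


# Hypothetical toy data: 10 points, 3 of which fall in the slice.
golds = [1, 1, 1, 0, 0, 0, 0, 1, 0, 1]
preds = [0, 0, 1, 0, 0, 0, 0, 1, 0, 1]
slice_masks = {"short_link": [1, 1, 1, 0, 0, 0, 0, 0, 0, 0]}

scores = score_by_slice(slice_masks, golds, preds)
print(scores)
```

In this toy example both errors land inside the slice, so overall accuracy is 0.8 while slice accuracy is only 1/3, the same kind of gap seen in the table above.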

Write additional slicing functions (SFs)

Slices are dynamic — as monitoring needs grow or change with new data distributions or application needs, an ML pipeline might require dozens, or even hundreds, of slices.

We’ll take inspiration from the labeling tutorial to write additional slicing functions. We demonstrate how the same powerful preprocessors and utilities available for labeling functions can be leveraged for slicing functions.

from snorkel.slicing import SlicingFunction, slicing_function
from snorkel.preprocess import preprocessor


# Keyword-based SFs
def keyword_lookup(x, keywords):
    return any(word in x.text.lower() for word in keywords)


def make_keyword_sf(keywords):
    return SlicingFunction(
        name=f"keyword_{keywords[0]}",
        f=keyword_lookup,
        resources=dict(keywords=keywords),
    )


keyword_subscribe = make_keyword_sf(keywords=["subscribe"])
keyword_please = make_keyword_sf(keywords=["please", "plz"])


# Regex-based SF
@slicing_function()
def regex_check_out(x):
    return bool(re.search(r"check.*out", x.text, flags=re.I))


# Heuristic-based SF
@slicing_function()
def short_comment(x):
    """Ham comments are often short, such as 'cool video!'"""
    return len(x.text.split()) < 5


# Leverage preprocessor in SF
from textblob import TextBlob


@preprocessor(memoize=True)
def textblob_sentiment(x):
    scores = TextBlob(x.text)
    x.polarity = scores.sentiment.polarity
    return x


@slicing_function(pre=[textblob_sentiment])
def textblob_polarity(x):
    return x.polarity > 0.9
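
Before applying these SFs to a DataFrame, it can help to check their logic on a few hand-written points. The sketch below re-declares keyword_lookup and short_comment as plain functions (without the Snorkel wrappers) and uses types.SimpleNamespace as a stand-in for a data point with a text attribute; the example comments are hypothetical.

```python
from types import SimpleNamespace


def keyword_lookup(x, keywords):
    # Same substring test as the keyword SFs above.
    return any(word in x.text.lower() for word in keywords)


def short_comment(x):
    # Same length heuristic as the short_comment SF above.
    return len(x.text.split()) < 5


# Hypothetical data points; SimpleNamespace stands in for a DataFrame row.
points = [
    SimpleNamespace(text="PLZ subscribe to my channel"),
    SimpleNamespace(text="I was pleased with this video, thanks a lot"),
    SimpleNamespace(text="cool video!"),
]

please_hits = [keyword_lookup(x, ["please", "plz"]) for x in points]
short_hits = [short_comment(x) for x in points]

print(please_hits)
print(short_hits)
```

Note that the second point lands in the keyword slice only because "pleased" contains "please": the lookup is a substring test, not a whole-word match, which is another source of noise in the slice definition.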

Again, we’d like to visualize data points in a particular slice. This time, we’ll inspect the textblob_polarity slice.

Most data points with high-polarity sentiments are strong opinions about the video. Hence, they are usually relevant to the video, and the corresponding labels are $0$. We might define a slice here for product and marketing reasons: it’s important to make sure that we don’t misclassify very positive comments from good users.

polarity_df = slice_dataframe(df_valid, textblob_polarity)
polarity_df[["text", "label"]].head()
     text                                  label
 16  Love this song !!!!!!                 0
309  One of the best song of all the time  0
164  She is perfect                        0
310  Best world cup offical song           0
352  I remember this :D                    0

We can evaluate performance on all SFs using the model-agnostic Scorer.

extra_sfs = [
    keyword_subscribe,
    keyword_please,
    regex_check_out,
    short_comment,
    textblob_polarity,
]

sfs = [short_link] + extra_sfs
slice_names = [sf.name for sf in sfs]

Let’s see how the sklearn model we learned before performs on these new slices!

applier = PandasSFApplier(sfs)
S_test = applier.apply(df_test)
scorer.score_slices(
    S=S_test, golds=Y_test, preds=preds_test, probs=probs_test, as_dataframe=True
)
                   accuracy  f1
overall            0.928000  0.925000
short_link         0.333333  0.500000
keyword_subscribe  0.944444  0.971429
keyword_please     1.000000  1.000000
regex_check_out    1.000000  1.000000
short_comment      0.945652  0.666667
textblob_polarity  0.875000  0.727273

It looks like the model does extremely well on some slices of our small test set, and decently on others. At the very least, we may want to monitor these to make sure that as we iterate to improve certain slices like short_link, we don’t hurt the performance of others. Next, we’ll introduce a model that helps us to do this balancing act automatically!

3. Improve slice performance

In the following section, we demonstrate a modeling approach that we call Slice-based Learning, which improves performance by adding extra slice-specific representational capacity to whichever model we’re using. Intuitively, we’d like the model to learn representations that are better suited to handle data points in each slice. In our approach, we model each slice as a separate “expert task” in the style of multi-task learning; for further details of how slice-based learning works under the hood, check out the code (with paper coming soon)!

In other approaches, one might attempt to increase slice performance with techniques like oversampling (e.g. with PyTorch’s WeightedRandomSampler), effectively shifting the training distribution towards certain populations.

This might work with a small number of slices, but with hundreds or thousands of production slices at scale, it could quickly become intractable to tune upsampling weights per slice.
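
To see what the oversampling alternative involves, here is a minimal sketch that uses the standard library's random.choices as a stand-in for a weighted sampler. The dataset size, slice membership, and the SLICE_WEIGHT value are all hypothetical.

```python
import random
from collections import Counter

# Hypothetical slice membership for 1,000 training points: 2% are in the slice.
in_slice = [i < 20 for i in range(1000)]

# Hypothetical upsampling factor. In practice this is a tuned hyperparameter,
# and there would be one such knob per slice.
SLICE_WEIGHT = 10.0
weights = [SLICE_WEIGHT if member else 1.0 for member in in_slice]

# Expected fraction of slice points in a resampled epoch.
expected_fraction = (20 * SLICE_WEIGHT) / sum(weights)

# Draw a resampled "epoch" with replacement, proportional to the weights.
rng = random.Random(0)
sampled = rng.choices(range(1000), weights=weights, k=10_000)
counts = Counter(in_slice[i] for i in sampled)
observed_fraction = counts[True] / len(sampled)

print(f"expected: {expected_fraction:.3f}, observed: {observed_fraction:.3f}")
```

A single weight moves the slice from 2% of the data to roughly 17% of each resampled epoch. With many overlapping slices, each weight also changes the effective share of every other slice, which is what makes per-slice tuning hard to scale.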

Set up modeling pipeline with SlicingClassifier

Snorkel supports performance monitoring on slices using discriminative models from snorkel.slicing. To demonstrate this functionality, we’ll first set up the datasets and modeling pipeline in the PyTorch-based snorkel.classification package.

First, we initialize a dataloader for each split.

from utils import create_dict_dataloader

BATCH_SIZE = 64


train_dl = create_dict_dataloader(
    X_train, Y_train, "train", batch_size=BATCH_SIZE, shuffle=True
)
valid_dl = create_dict_dataloader(
    X_valid, Y_valid, "valid", batch_size=BATCH_SIZE, shuffle=False
)
test_dl = create_dict_dataloader(
    X_test, Y_test, "test", batch_size=BATCH_SIZE, shuffle=False
)

We’ll now initialize a SlicingClassifier:

  • base_architecture: We define a simple Multi-Layer Perceptron (MLP) in PyTorch to serve as the primary representation architecture. The SlicingClassifier is agnostic to the base architecture: you might leverage a Transformer model for text, or a ResNet for images.
  • head_dim: Identifies the final output feature dimension of the base_architecture.
  • slice_names: Specifies the slices that we plan to train on with this classifier.

from snorkel.slicing import SlicingClassifier
from utils import get_pytorch_mlp


# Define model architecture
bow_dim = X_train.shape[1]
hidden_dim = bow_dim
mlp = get_pytorch_mlp(hidden_dim=hidden_dim, num_layers=2)

# Init slice model
slice_model = SlicingClassifier(
    base_architecture=mlp, head_dim=hidden_dim, slice_names=[sf.name for sf in sfs]
)

Monitor slice performance during training

Using Snorkel’s Trainer, we fit to train_dl, and validate on valid_dl.

We note that we can monitor slice-specific performance during training, which is a powerful way to track especially critical subsets of the data. If logging in TensorBoard (e.g. with snorkel.classification.TensorBoardWriter), we could visualize individual loss curves and validation metrics to debug convergence for specific slices.

from snorkel.classification import Trainer

# For demonstration purposes, we set n_epochs=2
trainer = Trainer(lr=1e-4, n_epochs=2)
trainer.fit(slice_model, [train_dl, valid_dl])
Epoch 0:: 100%|██████████| 25/25 [00:39<00:00,  1.58s/it, model/all/train/loss=0.472, model/all/train/lr=0.0001, task/SnorkelDataset/valid/accuracy=0.908, task/SnorkelDataset/valid/f1=0.893]
Epoch 1:: 100%|██████████| 25/25 [00:39<00:00,  1.60s/it, model/all/train/loss=0.0931, model/all/train/lr=0.0001, task/SnorkelDataset/valid/accuracy=0.933, task/SnorkelDataset/valid/f1=0.926]

Representation learning with slices

To cope with scale, we will attempt to learn and combine many slice-specific representations with an attention mechanism. (For details about this approach, please see our technical report — coming soon!)
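
The general idea of combining several representations with attention can be sketched in a few lines. This is a generic softmax-weighted sum, shown only to build intuition; it is not Snorkel's implementation, and the representation vectors and relevance scores below are hypothetical.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def combine(representations, scores):
    """Attention-weighted sum of equally sized representation vectors."""
    weights = softmax(scores)
    dim = len(representations[0])
    return [
        sum(w * rep[d] for w, rep in zip(weights, representations))
        for d in range(dim)
    ]


# Hypothetical 2-d representations from three slice "experts" for one point.
representations = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# Hypothetical relevance scores, e.g. how strongly each expert believes the
# point belongs to its slice. Equal scores give a plain average.
uniform = combine(representations, [0.0, 0.0, 0.0])
# One dominant score makes the output follow that expert almost entirely.
peaked = combine(representations, [10.0, 0.0, 0.0])

print(uniform)
print(peaked)
```

The combined vector has the same dimension no matter how many experts contribute, so adding a slice adds one more term to the weighted sum rather than a new hand-tuned sampling weight.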

First, we’ll generate the remaining S matrices with the new set of slicing functions.

applier = PandasSFApplier(sfs)
S_train = applier.apply(df_train)
S_valid = applier.apply(df_valid)

In order to train using slice information, we’d like to initialize a slice-aware dataloader. To do this, we can use slice_model.make_slice_dataloader to add slice labels to an existing dataloader.

Under the hood, this method leverages slice metadata to add slice labels to the appropriate fields so that the dataloader is compatible with the initialized SlicingClassifier.

train_dl_slice = slice_model.make_slice_dataloader(
    train_dl.dataset, S_train, shuffle=True, batch_size=BATCH_SIZE
)
valid_dl_slice = slice_model.make_slice_dataloader(
    valid_dl.dataset, S_valid, shuffle=False, batch_size=BATCH_SIZE
)
test_dl_slice = slice_model.make_slice_dataloader(
    test_dl.dataset, S_test, shuffle=False, batch_size=BATCH_SIZE
)
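
One plausible way to picture the added slice labels is sketched below: each slice gets an indicator label (is the point in the slice?) and a prediction label (the task label inside the slice, ignored elsewhere). The _ind and _pred suffixes follow the names that appear in the training logs; the helper make_slice_labels and the use of -1 as a "masked" sentinel are assumptions made for illustration, not a description of the library internals.

```python
MASKED = -1  # assumed sentinel meaning "ignore this point for this head"


def make_slice_labels(task_labels, slice_masks):
    """Derive indicator and masked prediction labels for each slice."""
    labels = {}
    for name, mask in slice_masks.items():
        # Indicator head: is the point in the slice?
        labels[f"{name}_ind"] = [int(bool(m)) for m in mask]
        # Prediction head: the task label inside the slice, masked elsewhere.
        labels[f"{name}_pred"] = [
            y if m else MASKED for y, m in zip(task_labels, mask)
        ]
    return labels


# Hypothetical task labels (1 = spam, 0 = ham) and slice membership.
task_labels = [1, 0, 1, 0]
slice_masks = {"short_link": [1, 0, 0, 0], "short_comment": [0, 1, 0, 1]}

slice_labels = make_slice_labels(task_labels, slice_masks)
for key, values in slice_labels.items():
    print(key, values)
```

Under this picture, a slice's prediction head only receives training signal from points inside the slice, which is consistent with treating each slice as its own expert task.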

We train a single model initialized with all slice tasks.

from snorkel.classification import Trainer

# For demonstration purposes, we set n_epochs=2
trainer = Trainer(n_epochs=2, lr=1e-4, progress_bar=True)
trainer.fit(slice_model, [train_dl_slice, valid_dl_slice])
Epoch 0::  96%|█████████▌| 24/25 [00:41<00:01,  1.79s/it, model/all/train/loss=0.376, model/all/train/lr=0.0001]/home/ubuntu/snorkel-tutorials/.tox/spam/lib/python3.6/site-packages/sklearn/metrics/classification.py:1437: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 due to no predicted samples.
  'precision', 'predicted', average, warn_for)
Epoch 0:: 100%|██████████| 25/25 [00:43<00:00,  1.73s/it, model/all/train/loss=0.371, model/all/train/lr=0.0001, task/SnorkelDataset/valid/accuracy=0.933, task/SnorkelDataset/valid/f1=0.926, task_slice:short_link_ind/SnorkelDataset/valid/f1=0, task_slice:short_link_pred/SnorkelDataset/valid/accuracy=0.8, task_slice:short_link_pred/SnorkelDataset/valid/f1=0.889, task_slice:keyword_subscribe_ind/SnorkelDataset/valid/f1=0, task_slice:keyword_subscribe_pred/SnorkelDataset/valid/accuracy=1, task_slice:keyword_subscribe_pred/SnorkelDataset/valid/f1=1, task_slice:keyword_please_ind/SnorkelDataset/valid/f1=0, task_slice:keyword_please_pred/SnorkelDataset/valid/accuracy=1, task_slice:keyword_please_pred/SnorkelDataset/valid/f1=1, task_slice:regex_check_out_ind/SnorkelDataset/valid/f1=0.471, task_slice:regex_check_out_pred/SnorkelDataset/valid/accuracy=1, task_slice:regex_check_out_pred/SnorkelDataset/valid/f1=1, task_slice:short_comment_ind/SnorkelDataset/valid/f1=0, task_slice:short_comment_pred/SnorkelDataset/valid/accuracy=0.947, task_slice:short_comment_pred/SnorkelDataset/valid/f1=0.5, task_slice:textblob_polarity_ind/SnorkelDataset/valid/f1=0, task_slice:textblob_polarity_pred/SnorkelDataset/valid/accuracy=1, task_slice:textblob_polarity_pred/SnorkelDataset/valid/f1=1, task_slice:base_ind/SnorkelDataset/valid/f1=1, task_slice:base_pred/SnorkelDataset/valid/accuracy=0.933, task_slice:base_pred/SnorkelDataset/valid/f1=0.926]
Epoch 1:: 100%|██████████| 25/25 [00:47<00:00,  1.88s/it, model/all/train/loss=0.17, model/all/train/lr=0.0001, task/SnorkelDataset/valid/accuracy=0.925, task/SnorkelDataset/valid/f1=0.914, task_slice:short_link_ind/SnorkelDataset/valid/f1=0, task_slice:short_link_pred/SnorkelDataset/valid/accuracy=0.2, task_slice:short_link_pred/SnorkelDataset/valid/f1=0.333, task_slice:keyword_subscribe_ind/SnorkelDataset/valid/f1=0.333, task_slice:keyword_subscribe_pred/SnorkelDataset/valid/accuracy=1, task_slice:keyword_subscribe_pred/SnorkelDataset/valid/f1=1, task_slice:keyword_please_ind/SnorkelDataset/valid/f1=0.5, task_slice:keyword_please_pred/SnorkelDataset/valid/accuracy=1, task_slice:keyword_please_pred/SnorkelDataset/valid/f1=1, task_slice:regex_check_out_ind/SnorkelDataset/valid/f1=0.791, task_slice:regex_check_out_pred/SnorkelDataset/valid/accuracy=1, task_slice:regex_check_out_pred/SnorkelDataset/valid/f1=1, task_slice:short_comment_ind/SnorkelDataset/valid/f1=0, task_slice:short_comment_pred/SnorkelDataset/valid/accuracy=0.947, task_slice:short_comment_pred/SnorkelDataset/valid/f1=0.5, task_slice:textblob_polarity_ind/SnorkelDataset/valid/f1=0, task_slice:textblob_polarity_pred/SnorkelDataset/valid/accuracy=1, task_slice:textblob_polarity_pred/SnorkelDataset/valid/f1=1, task_slice:base_ind/SnorkelDataset/valid/f1=1, task_slice:base_pred/SnorkelDataset/valid/accuracy=0.908, task_slice:base_pred/SnorkelDataset/valid/f1=0.893]

At inference time, the primary task head (labeled task in the output below) makes all final predictions. We’d like to evaluate every slice using that original task head: score_slices remaps all slice-related labels, denoted task_slice:{slice_name}_pred, so that they are evaluated on the primary task.

slice_model.score_slices([valid_dl_slice, test_dl_slice], as_dataframe=True)
    label                              dataset         split  metric    score
 0  task                               SnorkelDataset  valid  accuracy  0.925000
 1  task                               SnorkelDataset  valid  f1        0.914286
 2  task_slice:short_link_pred         SnorkelDataset  valid  accuracy  0.400000
 3  task_slice:short_link_pred         SnorkelDataset  valid  f1        0.571429
 4  task_slice:keyword_subscribe_pred  SnorkelDataset  valid  accuracy  1.000000
 5  task_slice:keyword_subscribe_pred  SnorkelDataset  valid  f1        1.000000
 6  task_slice:keyword_please_pred     SnorkelDataset  valid  accuracy  1.000000
 7  task_slice:keyword_please_pred     SnorkelDataset  valid  f1        1.000000
 8  task_slice:regex_check_out_pred    SnorkelDataset  valid  accuracy  1.000000
 9  task_slice:regex_check_out_pred    SnorkelDataset  valid  f1        1.000000
10  task_slice:short_comment_pred      SnorkelDataset  valid  accuracy  0.947368
11  task_slice:short_comment_pred      SnorkelDataset  valid  f1        0.500000
12  task_slice:textblob_polarity_pred  SnorkelDataset  valid  accuracy  1.000000
13  task_slice:textblob_polarity_pred  SnorkelDataset  valid  f1        1.000000
14  task_slice:base_pred               SnorkelDataset  valid  accuracy  0.925000
15  task_slice:base_pred               SnorkelDataset  valid  f1        0.914286
16  task                               SnorkelDataset  test   accuracy  0.932000
17  task                               SnorkelDataset  test   f1        0.922374
18  task_slice:short_link_pred         SnorkelDataset  test   accuracy  0.333333
19  task_slice:short_link_pred         SnorkelDataset  test   f1        0.500000
20  task_slice:keyword_subscribe_pred  SnorkelDataset  test   accuracy  0.861111
21  task_slice:keyword_subscribe_pred  SnorkelDataset  test   f1        0.925373
22  task_slice:keyword_please_pred     SnorkelDataset  test   accuracy  0.956522
23  task_slice:keyword_please_pred     SnorkelDataset  test   f1        0.977778
24  task_slice:regex_check_out_pred    SnorkelDataset  test   accuracy  1.000000
25  task_slice:regex_check_out_pred    SnorkelDataset  test   f1        1.000000
26  task_slice:short_comment_pred      SnorkelDataset  test   accuracy  0.967391
27  task_slice:short_comment_pred      SnorkelDataset  test   f1        0.769231
28  task_slice:textblob_polarity_pred  SnorkelDataset  test   accuracy  0.916667
29  task_slice:textblob_polarity_pred  SnorkelDataset  test   f1        0.800000
30  task_slice:base_pred               SnorkelDataset  test   accuracy  0.932000
31  task_slice:base_pred               SnorkelDataset  test   f1        0.922374

Note: in this toy dataset, we see high variance in slice performance, because our dataset is so small that (i) there are few data points in the train split, giving little signal to learn over, and (ii) there are few data points in the test split, making our evaluation metrics very noisy. For a demonstration of data slicing deployed in state-of-the-art models, please see our SuperGLUE tutorials.
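
A quick back-of-the-envelope calculation shows why small slices give noisy metrics. The sketch below uses the binomial standard error of an accuracy estimate; the slice sizes and the fixed accuracy of 0.9 are hypothetical values chosen for illustration, not measurements from this tutorial.

```python
import math


def accuracy_standard_error(accuracy, n):
    """Binomial standard error of an accuracy estimated from n points."""
    return math.sqrt(accuracy * (1.0 - accuracy) / n)


# Hypothetical slice sizes; the underlying accuracy is held fixed at 0.9.
errors = {n: accuracy_standard_error(0.9, n) for n in (5, 50, 500)}

for n, se in errors.items():
    print(f"n={n:>3}: 0.9 +/- {se:.3f}")
```

With only 5 points in a slice, the standard error is about 0.13, so a single misclassified example moves the reported accuracy by 20 points. At 500 points it shrinks to about 0.013.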


Recap

This tutorial walked through the process of authoring slices, monitoring model performance on specific slices, and improving model performance using slice information. This programming abstraction provides a mechanism to heuristically identify critical data subsets. For more technical details about Slice-based Learning, stay tuned: our technical report is coming soon!