FastAI: Multi-Label Classification [Chapter-6]

PyTorch and fastai have two main classes for representing and accessing a training set or validation set:

  • Dataset: A collection that returns a tuple of your independent and dependent variables for a single item
  • DataLoader: An iterator that provides a stream of mini-batches, where each mini-batch is a tuple of a batch of independent variables and a batch of dependent variables
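To make the Dataset/DataLoader distinction concrete, here is a dependency-free toy sketch. ToyDataset and toy_dataloader are hypothetical names, not PyTorch's classes. In real code you would subclass torch.utils.data.Dataset and pass it to torch.utils.data.DataLoader, which also collates each batch into tensors. The sketch only shows the shape of the data: one (x, y) tuple per item, and one (batch of xs, batch of ys) tuple per mini-batch.

```python
import random


class ToyDataset:
    """Hypothetical stand-in for a PyTorch-style Dataset: it has a length and
    is indexable, and each item is an (independent, dependent) tuple."""

    def __init__(self, xs, ys):
        assert len(xs) == len(ys)
        self.xs, self.ys = xs, ys

    def __len__(self):
        return len(self.xs)

    def __getitem__(self, i):
        return self.xs[i], self.ys[i]


def toy_dataloader(ds, batch_size, shuffle=False, seed=None):
    """Hypothetical loader: yields mini-batches as (batch_of_xs, batch_of_ys)."""
    idxs = list(range(len(ds)))
    if shuffle:
        random.Random(seed).shuffle(idxs)
    for start in range(0, len(idxs), batch_size):
        items = [ds[i] for i in idxs[start:start + batch_size]]
        xs, ys = zip(*items)  # regroup item tuples into a tuple of batches
        yield list(xs), list(ys)


ds = ToyDataset(list(range(10)), [x * 2 for x in range(10)])
print(ds[3])  # (3, 6)
for xb, yb in toy_dataloader(ds, batch_size=4):
    print(xb, yb)  # first batch: [0, 1, 2, 3] [0, 2, 4, 6]
```

The zip(*items) line is the key step: the loader receives a list of per-item tuples and turns it into one tuple of batches, roughly what PyTorch's default collate function does before stacking each side into a tensor.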

On top of these, fastai provides two classes for bringing your training and validation sets together:

  • Datasets: An object that contains a training Dataset and a validation Dataset
  • DataLoaders: An object that contains a training DataLoader and a validation DataLoader
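Datasets and DataLoaders mainly exist to keep the training and validation halves side by side. The sketch below is a hypothetical, pure-Python stand-in: random_split and ToySplit are illustrative names, loosely in the spirit of fastai's RandomSplitter and the .train/.valid attributes that Datasets and DataLoaders expose, not their actual implementation. It shuffles indices with a fixed seed and sends a fraction of the items to validation.

```python
import random
from dataclasses import dataclass


@dataclass
class ToySplit:
    """Hypothetical container pairing a training set with a validation set."""
    train: list
    valid: list


def random_split(items, valid_pct=0.2, seed=42):
    """Shuffle indices and send int(valid_pct * len(items)) items to validation."""
    idxs = list(range(len(items)))
    random.Random(seed).shuffle(idxs)  # fixed seed gives a reproducible split
    cut = int(valid_pct * len(items))
    valid = [items[i] for i in idxs[:cut]]
    train = [items[i] for i in idxs[cut:]]
    return ToySplit(train=train, valid=valid)


items = [(x, x * 2) for x in range(10)]  # (independent, dependent) tuples
dsets = random_split(items, valid_pct=0.2)
print(len(dsets.train), len(dsets.valid))  # 8 2
```

Fixing the seed keeps the split identical across runs, so validation numbers stay comparable between experiments.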

The Learner object contains four main things:

  • The model
  • A DataLoaders object
  • An optimizer
  • The loss function to use
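The four pieces can be seen working together in a toy training loop. Everything here is hypothetical and dependency-free: ToyLearner, the linear model, the SGD step and the finite-difference gradients are illustrative stand-ins, whereas fastai's Learner uses PyTorch modules, optimizers and autograd. The data follows y = 3x + 1, so training should recover w ≈ 3 and b ≈ 1.

```python
class ToyLearner:
    """Hypothetical bundle of the four things a Learner holds."""

    def __init__(self, model, dls, opt_step, loss_func):
        self.model = model          # maps (params, x) -> prediction
        self.dls = dls              # {"train": batches, "valid": batches}
        self.opt_step = opt_step    # updates params from gradients
        self.loss_func = loss_func  # scores one prediction against its target

    def batch_loss(self, params, xb, yb):
        """Mean loss over one mini-batch."""
        return sum(self.loss_func(self.model(params, x), y)
                   for x, y in zip(xb, yb)) / len(xb)

    def grads(self, params, xb, yb, h=1e-6):
        """Central finite differences, standing in for autograd."""
        out = []
        for i in range(len(params)):
            up, down = list(params), list(params)
            up[i] += h
            down[i] -= h
            diff = self.batch_loss(up, xb, yb) - self.batch_loss(down, xb, yb)
            out.append(diff / (2 * h))
        return out

    def fit(self, params, epochs, lr):
        for _ in range(epochs):
            for xb, yb in self.dls["train"]:
                params = self.opt_step(params, self.grads(params, xb, yb), lr)
        return params

    def valid_loss(self, params):
        batches = self.dls["valid"]
        return sum(self.batch_loss(params, xb, yb) for xb, yb in batches) / len(batches)


def linear(params, x):
    """Model: y = w*x + b."""
    w, b = params
    return w * x + b


def squared_error(pred, target):
    """Loss function for a single item."""
    return (pred - target) ** 2


def sgd(params, grads, lr):
    """Optimizer: plain gradient descent step."""
    return [p - lr * g for p, g in zip(params, grads)]


xs = [i / 10 for i in range(10)]
train = [(xs[:5], [3 * x + 1 for x in xs[:5]]),
         (xs[5:], [3 * x + 1 for x in xs[5:]])]
valid = [([1.0, 1.5], [4.0, 5.5])]

learn = ToyLearner(linear, {"train": train, "valid": valid}, sgd, squared_error)
w, b = learn.fit([0.0, 0.0], epochs=1000, lr=0.5)
print(round(w, 3), round(b, 3), learn.valid_loss([w, b]))
```

Swapping any one piece (a different model, optimizer, loss or data) leaves the training loop untouched, which is the point of bundling them in one object.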

PyTorch provides F.binary_cross_entropy and its module equivalent nn.BCELoss, which calculate binary cross entropy on a one-hot-encoded target but do not include the initial sigmoid, so they expect probabilities rather than raw activations. For one-hot-encoded targets you will normally want F.binary_cross_entropy_with_logits (or nn.BCEWithLogitsLoss), which applies the sigmoid and the binary cross entropy in a single function.
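The reason to prefer the fused version is numerical. In the pure-Python sketch below (not PyTorch's implementation), binary_cross_entropy expects a probability, like F.binary_cross_entropy, while bce_with_logits takes the raw activation and uses the standard stable rearrangement max(x, 0) - x*y + log(1 + exp(-|x|)), which is algebraically equal to applying the sigmoid first. With a large activation the sigmoid rounds to exactly 1.0 in floating point, and the naive log(1 - p) breaks.

```python
import math


def sigmoid(x):
    return 1 / (1 + math.exp(-x))


def binary_cross_entropy(p, y):
    """Naive BCE on a probability p, with no safeguards."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def bce_with_logits(x, y):
    """BCE on a raw activation x, with the sigmoid folded in stably."""
    return max(x, 0) - x * y + math.log1p(math.exp(-abs(x)))


# For moderate activations both routes agree
for x, y in [(2.0, 1.0), (-1.5, 0.0)]:
    print(binary_cross_entropy(sigmoid(x), y), bce_with_logits(x, y))

# sigmoid(40.0) rounds to exactly 1.0, so log(1 - p) is log(0)
try:
    binary_cross_entropy(sigmoid(40.0), 0.0)
except ValueError as err:
    print("naive version fails:", err)
print(bce_with_logits(40.0, 0.0))  # 40.0
```

PyTorch's own nn.BCELoss guards against this by clamping its log outputs at -100, while the fused loss never forms the rounded probability in the first place; the PyTorch documentation describes BCEWithLogitsLoss as more numerically stable for this reason.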

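In fastai we usually do not need to specify the loss function: fastai picks one automatically based on the type of targets defined in the DataLoaders, although you can still override it by passing loss_func. For multi-label classification it uses BCEWithLogitsLossFlat by default, which is nn.BCEWithLogitsLoss with its inputs and targets flattened.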
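What makes the task multi-label is the shape of the target: each image can have zero, one or several labels, so the target becomes a vector with a 1 for every label present. fastai's MultiCategoryBlock builds this encoding for you; the pure-Python sketch below is a hypothetical re-creation with a made-up five-class vocab, plus a decode step that applies a sigmoid to each activation independently and keeps every class above a threshold (0.5 here is an assumption chosen for illustration). Because each class is its own yes/no question, a per-class sigmoid with binary cross entropy fits, rather than a softmax that forces the classes to compete.

```python
import math

# Hypothetical five-class vocab, chosen for illustration only
vocab = ["bicycle", "car", "chair", "dog", "person"]


def multi_hot(labels, vocab):
    """Encode a list of label names as a 0/1 vector over vocab."""
    index = {name: i for i, name in enumerate(vocab)}
    vec = [0.0] * len(vocab)
    for name in labels:
        vec[index[name]] = 1.0
    return vec


def decode(activations, vocab, thresh=0.5):
    """Sigmoid each activation independently; keep every class above thresh."""
    return [name for name, x in zip(vocab, activations)
            if 1 / (1 + math.exp(-x)) > thresh]


print(multi_hot(["person", "chair"], vocab))        # [0.0, 0.0, 1.0, 0.0, 1.0]
print(multi_hot([], vocab))                         # [0.0, 0.0, 0.0, 0.0, 0.0]
print(decode([-3.0, -1.0, 2.5, -0.2, 4.0], vocab))  # ['chair', 'person']
```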

If you build your data loaders in pure PyTorch, nothing picks the loss function for you, so choose it deliberately. You most probably want:

  • nn.CrossEntropyLoss for single-label classification
  • nn.BCEWithLogitsLoss for multi-label classification
  • nn.MSELoss for regression
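The three choices differ in what they assume about the model's output. The pure-Python sketch below reproduces the underlying math for a single example; the helper names are hypothetical and this is not the PyTorch modules' code (those also average over the batch by default). Softmax turns the activations into probabilities that sum to 1, so cross entropy suits exactly one correct class; independent sigmoids let any subset of classes be likely at once, which suits multi-label targets; and MSE compares real-valued predictions directly.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(x):
    return 1 / (1 + math.exp(-x))


def cross_entropy(logits, target_idx):
    """Single-label: -log of the softmax probability of the one true class."""
    return -math.log(softmax(logits)[target_idx])


def bce_with_logits(logits, targets):
    """Multi-label: stable sigmoid + binary cross entropy per class, averaged."""
    per_class = [max(x, 0) - x * y + math.log1p(math.exp(-abs(x)))
                 for x, y in zip(logits, targets)]
    return sum(per_class) / len(per_class)


def mse(preds, targets):
    """Regression: mean of squared differences."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


logits = [2.0, -1.0, 0.5]
print(sum(softmax(logits)))                # 1.0, up to rounding
print(sum(sigmoid(x) for x in logits))     # can exceed 1: classes are independent
print(cross_entropy(logits, 0))            # class 0 is the single correct label
print(bce_with_logits(logits, [1, 0, 1]))  # classes 0 and 2 are both present
print(mse([2.5, 0.0], [3.0, -0.5]))        # 0.25
```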

