Intermediate Computer Vision: Episode 2
In this lesson, you will explore how to use the data utilities in PyTorch to load your data efficiently. You can find the corresponding Jupyter notebook here. If you already know about PyTorch, you may want to skip ahead to episode 4, where we start modeling.
You will use the HaGRID dataset from the previous episode to create a custom torch.utils.data.Dataset and a corresponding torch.utils.data.DataLoader to feed the data to a model. The end result will be a custom class called GestureDataset that we can use to build reliable data pipelines in the remainder of this tutorial.
None of the patterns you will learn in this episode are unique to this example or to image data, so you will be able to adapt these lessons to work with any dataset you want to model with PyTorch.
1. Why use a Torch DataLoader?
PyTorch's built-in Dataset and DataLoader objects simplify the steps between ingesting data and feeding it to a model. They provide abstractions that address requirements common to most, if not all, deep learning scenarios:
- The Dataset defines the structure of the data and how to fetch individual instances.
- The DataLoader uses the Dataset to load batches of data that can easily be shuffled, sampled, transformed, etc.
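To make this division of labor concrete, here is a minimal sketch with a toy in-memory dataset. The class `SquaresDataset` is hypothetical and not part of HaGRID: the Dataset only answers "what is item i?", and the DataLoader turns those items into batches.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: item i is the pair (i, i**2) as float tensors."""

    def __init__(self, n: int):
        self.n = n

    def __len__(self) -> int:
        # Number of instances; the DataLoader's sampler uses this to generate indices.
        return self.n

    def __getitem__(self, index: int):
        # Fetch a single instance; the Dataset knows nothing about batching.
        x = torch.tensor(float(index))
        y = torch.tensor(float(index) ** 2)
        return x, y


dataset = SquaresDataset(6)
loader = DataLoader(dataset, batch_size=4)  # batching is handled by the DataLoader

for xb, yb in loader:
    # The default collate function stacks per-item tensors into batch tensors.
    print(xb.tolist(), yb.tolist())
```

With 6 items and a batch size of 4, the loop yields one full batch of 4 followed by a partial batch of 2.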
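Importantly for many computer vision use cases, this functionality is built to scale to training large networks on large datasets, and it offers many optimization avenues for advanced users, such as loading batches in parallel worker processes (num_workers) and using pinned memory for faster transfers to the GPU (pin_memory).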
2. What is a Torch DataLoader?
The torch.utils.data.DataLoader class helps you efficiently access batches from a dataset so you can feed them into your model.
The DataLoader constructor has a signature like this (the exact parameters and defaults vary across PyTorch versions):
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)
You can read more details here.
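A quick way to see what batch_size, shuffle, and drop_last do is to pass the DataLoader a plain list of integers. A list works as a map-style dataset because it supports len() and indexing. The generator argument (available in current PyTorch releases, though not in the signature listed above) is used here only to make the shuffled order reproducible; the exact permutation depends on the seed and the PyTorch version, so no expected order is shown.

```python
import torch
from torch.utils.data import DataLoader

data = list(range(10))  # a list supports len() and [], so it works as a map-style dataset

# 10 items in batches of 4 -> batches of size 4, 4, and a final partial batch of 2.
sequential = DataLoader(data, batch_size=4, shuffle=False)
print([batch.tolist() for batch in sequential])  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# drop_last=True discards the incomplete final batch.
dropped = DataLoader(data, batch_size=4, drop_last=True)
print(len(dropped))  # 2

# shuffle=True draws indices in a random order each epoch; a seeded generator makes it repeatable.
g = torch.Generator().manual_seed(0)
shuffled = DataLoader(data, batch_size=4, shuffle=True, generator=g)
order = [x for batch in shuffled for x in batch.tolist()]
print(sorted(order) == data)  # True: same items, possibly in a different order
```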
The most important argument is dataset, which should be an instance of a torch.utils.data.Dataset object. This object is what we will customize next. Then we can use it to instantiate DataLoader objects that follow the standard pattern for feeding data into a PyTorch model.
3. Build a Torch Dataset
To create a DataLoader, we need to pass it a Dataset.
There are two styles of Torch Dataset: map-style and iterable-style. A map-style dataset subclasses torch.utils.data.Dataset and defines the __getitem__ and __len__ methods, so items can be fetched by index. An iterable-style dataset subclasses torch.utils.data.IterableDataset and defines the __iter__ method, so items are produced one after another, which suits streaming data. You can read more about this distinction here.
For now, all you need to know is that in the rest of this episode you will build a custom map-style dataset for the HaGRID data called GestureDataset.
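Here is a minimal sketch of the two styles side by side; the class names MapStyle and StreamStyle are hypothetical. Both can be batched by a DataLoader, but only the map-style dataset can be indexed directly. Because an iterable-style dataset has no index, the DataLoader rejects shuffle=True for it, so any shuffling has to happen inside __iter__.

```python
from torch.utils.data import DataLoader, Dataset, IterableDataset


class MapStyle(Dataset):
    """Map-style: random access by index via __getitem__, size via __len__."""

    def __init__(self, items):
        self.items = list(items)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        return self.items[index]


class StreamStyle(IterableDataset):
    """Iterable-style: yields items in order via __iter__, e.g. from a file or a stream."""

    def __init__(self, items):
        self.items = list(items)

    def __iter__(self):
        # A real stream might read records lazily; here we just walk a list.
        for item in self.items:
            yield item


items = [10, 20, 30, 40, 50]
print(MapStyle(items)[2])  # 30: map-style datasets can be indexed directly

# Both styles can be batched by a DataLoader.
print([b.tolist() for b in DataLoader(MapStyle(items), batch_size=2)])     # [[10, 20], [30, 40], [50]]
print([b.tolist() for b in DataLoader(StreamStyle(items), batch_size=2)])  # [[10, 20], [30, 40], [50]]
```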
4. Example: Components of the GestureDataset
In all remaining notebook examples and flows in this tutorial, we will use the GestureDataset. Much of the code is reused from the original source, which you can view here. The end goal is to create a GestureDataset object that we can easily use in model training code like the following snippet:
model = _initialize_model(model_name, checkpoint_path, device)
# GestureDataset also needs the loaded config object `conf` (see the constructor in section 4a).
train_dataset = GestureDataset(is_train=True, conf=conf, transform=get_transform())
test_dataset = GestureDataset(is_train=False, conf=conf, transform=get_transform())
TrainClassifier.train(model, train_dataset, test_dataset, device)
This section shows how to implement the methods needed to use GestureDataset, or any custom dataset, as in the code above.
More than the details of this specific example, the main takeaway of this section is that when working with a custom Dataset class you need to:
- Subclass torch.utils.data.Dataset (or torch.utils.data.IterableDataset for an iterable-style dataset).
- Define the constructor.
- Define either the __getitem__ and __len__ methods (map style) or the __iter__ method (iterable style).
You can put whatever logic you want in these methods, as long as their signatures follow the PyTorch protocol.
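A minimal template that satisfies these three requirements might look like the following. It is loosely modeled on the shape of GestureDataset but uses random tensors instead of HaGRID images; the class name ToyGestureDataset and its records argument are hypothetical.

```python
import torch
from torch.utils.data import Dataset


class ToyGestureDataset(Dataset):  # 1. subclass torch.utils.data.Dataset
    def __init__(self, records, targets, transform=None):  # 2. define the constructor
        # records: list of (gesture_name, leading_hand) pairs standing in for real annotations.
        self.records = records
        self.transform = transform
        # Map each gesture name to an integer class index, as GestureDataset does with self.labels.
        self.labels = {label: num for num, label in enumerate(targets)}
        self.leading_hand = {"right": 0, "left": 1}

    def __len__(self):  # 3a. number of instances
        return len(self.records)

    def __getitem__(self, index):  # 3b. fetch one instance by integer index
        gesture, hand = self.records[index]
        image = torch.rand(3, 64, 64)  # placeholder for a loaded, resized image
        label = {"gesture": self.labels[gesture], "leading_hand": self.leading_hand[hand]}
        if self.transform is not None:
            image, label = self.transform(image, label)
        return image, label


ds = ToyGestureDataset([("like", "right"), ("peace", "left")], targets=["like", "peace", "stop"])
image, label = ds[1]
print(len(ds), image.shape, label)  # 2 torch.Size([3, 64, 64]) {'gesture': 1, 'leading_hand': 1}
```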
4a. The Dataset Constructor
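The Dataset constructor runs when you create the dataset object. For GestureDataset, the constructor does the following:
- Assigns instance attributes for the configuration object, the transformations, and the label mappings (gesture names and leading hands to integer indices).
- Splits the images and their annotations by user into training (80% of users) and validation (20% of users) sets; if is_test is set, all annotations are kept instead.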
import random

import torch


class GestureDataset(torch.utils.data.Dataset):
    def __init__(self, is_train, conf, transform=None, is_test=False):
        self.conf = conf
        self.transform = transform
        self.is_train = is_train
        # Map each gesture name in the config to an integer class index.
        self.labels = {
            label: num for (label, num) in zip(self.conf.dataset.targets, range(len(self.conf.dataset.targets)))
        }
        self.leading_hand = {"right": 0, "left": 1}
        subset = self.conf.dataset.get("subset", None)
        self.annotations = self.__read_annotations(subset)  # private helper, not shown here
        # Sort the user IDs, then shuffle them with a seeded RNG so the split is reproducible;
        # 80% of users go to training and the remaining 20% to validation.
        users = self.annotations["user_id"].unique()
        users = sorted(users)
        random.Random(self.conf.random_state).shuffle(users)
        train_users = users[: int(len(users) * 0.8)]
        val_users = users[int(len(users) * 0.8) :]
        self.annotations = self.annotations.copy()
        if not is_test:
            # Keep only the annotation rows that belong to this split's users.
            if is_train:
                self.annotations = self.annotations[self.annotations["user_id"].isin(train_users)]
            else:
                self.annotations = self.annotations[self.annotations["user_id"].isin(val_users)]

    ...
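Splitting by user rather than by image means no person's hands appear in both the training and validation sets, so validation accuracy is less likely to overestimate how well the model handles new people. Below is a minimal sketch of the same idea in plain Python, using a list of dicts in place of the pandas DataFrame that GestureDataset reads; the function name split_by_user, the seed value, and the sample annotations are hypothetical.

```python
import random


def split_by_user(annotations, train_fraction=0.8, seed=42):
    """Split annotation records into train/val so that no user appears in both.

    annotations: list of dicts with at least a "user_id" key.
    """
    # Sort first so the shuffle starts from the same order on every run,
    # then shuffle with a dedicated, seeded RNG for a reproducible split.
    users = sorted({row["user_id"] for row in annotations})
    random.Random(seed).shuffle(users)
    cut = int(len(users) * train_fraction)
    train_users, val_users = set(users[:cut]), set(users[cut:])
    train = [row for row in annotations if row["user_id"] in train_users]
    val = [row for row in annotations if row["user_id"] in val_users]
    return train, val


# Hypothetical annotations: 10 users with 3 images each.
annotations = [{"user_id": f"u{u}", "image": f"img_{u}_{k}"} for u in range(10) for k in range(3)]
train, val = split_by_user(annotations)
print(len(train), len(val))  # 24 6
```

Using a dedicated random.Random instance, as the constructor does, keeps the split independent of any other code that touches Python's global random state.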
4b. Getting a Data Instance
__getitem__ is a special method that lets instances of a Dataset class be indexed like a list with [], for example dataset[0]. In our case, we want this method to take an integer index and return an appropriately sized image and its label.
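class GestureDataset(torch.utils.data.Dataset):
    ...

    def __getitem__(self, index: int):
        # Look up the annotation row for this index as a plain dict.
        row = self.annotations.iloc[[index]].to_dict("records")[0]
        # Load and resize the image for this row and get its gesture and
        # leading-hand names (private helper, not shown here).
        image_resized, gesture, leading_hand = self.__prepare_image_target(
            row["target"], row["name"], row["bboxes"], row["labels"], row["leading_hand"]
        )
        # Convert the string labels to integer indices.
        label = {"gesture": self.labels[gesture], "leading_hand": self.leading_hand[leading_hand]}
        # Apply the optional transform to the image and label together.
        if self.transform is not None:
            image_resized, label = self.transform(image_resized, label)
        return image_resized, label

    ...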
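Because __getitem__ passes both the image and its label to self.transform, a transform can update the label whenever it changes the image. As an illustration only (a hypothetical transform, not necessarily part of HaGRID's get_transform), mirroring an image horizontally makes a right hand look like a left hand, so the leading_hand index should flip as well:

```python
import torch


def hflip_with_label(image: torch.Tensor, label: dict):
    """Hypothetical joint transform: mirror a CHW image and swap the leading hand.

    Uses the same encoding as GestureDataset: 0 = right, 1 = left.
    """
    flipped = torch.flip(image, dims=[-1])  # reverse the width axis
    new_label = dict(label, leading_hand=1 - label["leading_hand"])
    return flipped, new_label


image = torch.arange(6.0).reshape(1, 2, 3)  # 1 channel, 2 rows, 3 columns
out_image, out_label = hflip_with_label(image, {"gesture": 4, "leading_hand": 0})
print(out_image[0].tolist())  # [[2.0, 1.0, 0.0], [5.0, 4.0, 3.0]]
print(out_label)  # {'gesture': 4, 'leading_hand': 1}
```

A transform that only accepted the image could not keep the label consistent in this way, which is one reason to pass both.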
5. Example: Using the GestureDataset
In this section, you will use the GestureDataset to instantiate a DataLoader and visualize one batch of images with their labels.
First, we will import dependencies.
import torch
from hagrid.classifier.dataset import GestureDataset
from hagrid.classifier.preprocess import get_transform
from hagrid.classifier.utils import collate_fn
from omegaconf import OmegaConf
from math import sqrt
import matplotlib.pyplot as plt
path_to_config = './hagrid/classifier/config/default.yaml'
conf = OmegaConf.load(path_to_config)
Then we instantiate the GestureDataset implemented here.
train_dataset = GestureDataset(is_train=True, conf=conf, transform=get_transform())
Now you can use train_dataset to create a data loader that you can request batches from.
BATCH_SIZE = 16
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    # Increase num_workers to load data faster; feasible values depend on
    # your machine's CPU cores and memory.
    num_workers=1,
    collate_fn=collate_fn,
    # Exercise: what happens to the image grid displayed by the view_batch
    # function when you set shuffle=False in this constructor?
    shuffle=True,
)
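The collate_fn argument controls how the list of (image, label) pairs returned by __getitem__ is merged into a batch. PyTorch's default collation stacks tensors, which requires every image in the batch to have the same shape, and it merges a list of label dicts into a single dict of tensors. A collate function that simply regroups the pairs instead yields a tuple of images and a tuple of label dicts, which is what view_batch below iterates over. The sketch assumes HaGRID's helper behaves like this; check hagrid.classifier.utils.collate_fn for the actual implementation. The name regroup_collate and the sample data are hypothetical.

```python
import torch
from torch.utils.data import DataLoader


def regroup_collate(batch):
    """Turn [(img0, lbl0), (img1, lbl1), ...] into ((img0, img1, ...), (lbl0, lbl1, ...))."""
    return tuple(zip(*batch))


# Hypothetical samples whose images differ in size, which default stacking cannot batch.
samples = [
    (torch.zeros(3, 32, 32), {"gesture": 0, "leading_hand": 0}),
    (torch.zeros(3, 48, 40), {"gesture": 2, "leading_hand": 1}),
]
images, labels = next(iter(DataLoader(samples, batch_size=2, collate_fn=regroup_collate)))
print([tuple(img.shape) for img in images])  # [(3, 32, 32), (3, 48, 40)]
print(labels[1]["gesture"])  # 2
```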
Here is a helper function to show the contents of a batch:
def view_batch(images, labels, batch_size):
    from math import ceil, sqrt

    import matplotlib.pyplot as plt

    plt.ioff()  # don't open an interactive window; the figure is saved to disk instead
    # Near-square grid with at least `batch_size` cells (e.g. 16 -> 4x4, 7 -> 3x3).
    n_cols = ceil(sqrt(batch_size))
    n_rows = ceil(batch_size / n_cols)
    # squeeze=False keeps `axes` 2-D even when the grid has a single row.
    fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
    for ax in axes.flat:
        ax.axis("off")  # hide ticks and spines; also leaves unused cells blank
    for i, (image, label) in enumerate(zip(images, labels)):
        ax = axes[i // n_cols, i % n_cols]
        ax.imshow(image.permute(1, 2, 0))  # CHW tensor -> HWC layout for matplotlib
        # `conf` is the config loaded above; its targets list maps class indices to gesture names.
        ax.set_title(conf.dataset.targets[label["gesture"]], fontsize=10)
    fig.tight_layout()
    fig.savefig(fname="./dataloader-sample.png")
Now we can take the first batch from train_dataloader and view a grid of each image with its corresponding label.
images, labels = next(iter(train_dataloader))
view_batch(images, labels, BATCH_SIZE)
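One detail worth knowing: iter(train_dataloader) starts a fresh pass over the data each time it is called, so repeating next(iter(...)) does not step through the batches. With shuffle=True, each fresh pass also draws a new random order. A minimal sketch with a toy list dataset:

```python
from torch.utils.data import DataLoader

loader = DataLoader(list(range(10)), batch_size=4, shuffle=False)

# next(iter(loader)) builds a brand-new iterator each time, so it always
# returns the first batch of a new pass rather than advancing through the data.
first = next(iter(loader)).tolist()
again = next(iter(loader)).tolist()
print(first == again)  # True (with shuffle=False)

# To visit every batch, loop over the loader; each full loop is one epoch.
for epoch in range(2):
    seen = [x for batch in loader for x in batch.tolist()]
    print(epoch, len(loader), seen)  # len(loader) is 3 batches per epoch
```

In a training loop, the usual pattern is therefore `for images, labels in train_dataloader:` inside an outer loop over epochs.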
Nice! Getting a reliable data flow is a big step in any machine learning project, and in this lesson you have only scratched the surface of the tools PyTorch offers to help you do this. You learned about PyTorch datasets and data loaders, and you saw how to use them to efficiently and reliably load HaGRID dataset samples for training PyTorch models. Next, you will pair PyTorch data loaders with Metaflow features to extend these concepts to working with datasets and models in the cloud. See you there!