A compilation of concepts I want to remember...

IMDB-WIKI: notes on refactoring data preprocess pipeline

07 Apr 2018 » imdb, deeplearning, machinelearning

This note is an update to IMDB-WIKI: trying a small model for age classification, where I tried to simplify age classification by reducing the number of classes and using a model with relatively small capacity. That said, the main focus of that exercise was not the modelling. It was handling data in a relatively raw, unedited form: extracting it and loading it into a format a deep learning model can readily consume.

The first iteration was good enough for personal use, but its sloppiness quickly surfaced once others tried to use the scripts I had put together.

I had the opportunity to work on the repo again and refactored the scripts so that others can use them more easily, though some work is still required. The points below relate to the rework.

A few takeaways

  1. The new implementation uses PyTorch and its Dataset API to extract, transform, and load the data after preprocessing. Unlike the original Chainer implementation, this one needs paths to the images, so I added an option to imdb_preprocess.py that returns the input features as image paths, enabled by setting the --get-paths flag to True.

  2. The model for this exercise was a pretrained VGG16 with a redefined classifier block. [1] I had trouble working out the input size of the first layer inside nn.Sequential(), so I had to refer to the documentation.

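# 512 * 7 * 7: VGG16's last feature map has 512 channels at 7 x 7 for a 224 x 224 input.
# torchvision's reference classifier [1] also places nn.ReLU(True) and nn.Dropout()
# after each of the first two nn.Linear layers.
self.classifier = nn.Sequential(
    nn.Linear(512 * 7 * 7, 4096),
    nn.Linear(4096, 4096),
    nn.Linear(4096, num_classes),
)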
  3. The faces-only IMDB dataset contains images of all sizes and dimensions, with float values normalized between 0 and 1. Since VGG16 expects 3 input channels, I cropped each image and reduced it to a single grayscale channel, then took an additional step to rescale it to 0-255 and apply np.uint8(). [It is questionable whether this is the best approach; I would like to hear other suggestions, if any.] The final step was to convert to 3 channels, which simply copies the single-channel image across all 3 channels. The source can be found here. The conversion to uint8 is required because the torchvision.transforms.ToPILImage() transform used in transformer.py does not accept floats at the time of this writing. [2]

  4. On a second pass over the data, I noticed ages well beyond any valid upper limit, so I added a range check to imdb_preprocess.py. I would expect a closer investigation to surface further possible improvements.
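
The path-based features from the first takeaway pair naturally with a map-style dataset that keeps only the paths in memory and opens each image when it is indexed. The following is a minimal sketch under a few assumptions: the class name ImagePathDataset and its load_fn/transform arguments are hypothetical, and the class only mirrors the __len__/__getitem__ protocol of torch.utils.data.Dataset so that it runs without PyTorch installed. In the project itself you would subclass torch.utils.data.Dataset and pass a real image loader such as PIL.Image.open.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple


class ImagePathDataset:
    """Map-style dataset over (image_path, label) pairs (hypothetical sketch).

    Follows the same __len__/__getitem__ protocol as torch.utils.data.Dataset;
    in real code, subclass torch.utils.data.Dataset instead.
    """

    def __init__(
        self,
        paths: Sequence[str],
        labels: Sequence[int],
        load_fn: Callable[[str], Any],
        transform: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        if len(paths) != len(labels):
            raise ValueError("paths and labels must have the same length")
        self.paths: List[str] = list(paths)
        self.labels: List[int] = list(labels)
        self.load_fn = load_fn  # hypothetical hook, e.g. PIL.Image.open in practice
        self.transform = transform

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int) -> Tuple[Any, int]:
        # Only the path is stored; the image is read each time it is indexed.
        image = self.load_fn(self.paths[idx])
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]
```

Since a DataLoader only relies on __len__ and __getitem__ for map-style datasets, the same shape carries over once the class subclasses torch.utils.data.Dataset and the transform is a torchvision pipeline.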
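
The pixel conversion in the third takeaway can be sketched in plain Python, assuming a grayscale image given as rows of floats in [0, 1]. The function name gray_float_to_rgb_uint8 is hypothetical. Values are clipped, scaled to 0-255 and truncated to integers, which matches what np.uint8() does to non-negative floats, and each value is then copied into the R, G and B positions. With NumPy, the same steps would be (np.clip(gray, 0, 1) * 255).astype(np.uint8) followed by np.stack([img] * 3, axis=-1), giving an H x W x 3 uint8 array that ToPILImage() accepts.

```python
from typing import List, Sequence, Tuple

Pixel = Tuple[int, int, int]


def gray_float_to_rgb_uint8(gray: Sequence[Sequence[float]]) -> List[List[Pixel]]:
    """Convert rows of grayscale floats in [0, 1] to rows of 8-bit (R, G, B) pixels."""
    rgb: List[List[Pixel]] = []
    for row in gray:
        out_row: List[Pixel] = []
        for v in row:
            v = min(max(v, 0.0), 1.0)  # clip stray values outside [0, 1]
            b = int(v * 255)  # truncate, as np.uint8() does for non-negative floats
            out_row.append((b, b, b))  # copy the single channel into R, G and B
        rgb.append(out_row)
    return rgb


print(gray_float_to_rgb_uint8([[0.0, 1.0], [0.5, 0.25]]))
# [[(0, 0, 0), (255, 255, 255)], [(127, 127, 127), (63, 63, 63)]]
```

Truncation maps 0.5 to 127 rather than 128; rounding instead (for example int(v * 255 + 0.5)) is one alternative if that difference matters.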
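
A range check like the one in the fourth takeaway can be as simple as the filter below. The bounds of 0 and 100 and the (image_path, age) sample layout are assumptions for illustration only; the exact limits used in imdb_preprocess.py are not stated here.

```python
from typing import Iterable, List, Tuple

# Hypothetical bounds: the exact valid age range is an assumption.
MIN_AGE = 0
MAX_AGE = 100


def filter_valid_ages(
    samples: Iterable[Tuple[str, int]],
    min_age: int = MIN_AGE,
    max_age: int = MAX_AGE,
) -> List[Tuple[str, int]]:
    """Keep (image_path, age) pairs whose age lies in [min_age, max_age], inclusive."""
    return [(path, age) for path, age in samples if min_age <= age <= max_age]


print(filter_valid_ages([("a.jpg", 25), ("b.jpg", 150), ("c.jpg", -3)]))
# [('a.jpg', 25)]
```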

Next steps at some point

As this second iteration only refactored the data extraction and loading process, I have not spent much time on the modelling side and have included a pretrained VGG16 implementation as a starting point.
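
For that VGG16 starting point, the 512 * 7 * 7 input size of the first classifier layer can be derived by hand. The stdlib-only sketch below (the function name vgg16_flat_features is hypothetical) assumes the standard VGG16 feature extractor, where every 3x3 convolution uses padding 1 and so keeps the spatial size, and only the five 2x2, stride-2 max-pool layers shrink it. It also assumes no adaptive pooling sits between the features and the classifier; newer torchvision releases add an nn.AdaptiveAvgPool2d((7, 7)) there, which fixes the output at 7 x 7 regardless of input size.

```python
def vgg16_flat_features(input_size: int, channels: int = 512, num_pools: int = 5) -> int:
    """Length of VGG16's flattened final feature map for a square input image.

    Assumes 3x3 convolutions with padding=1 (size-preserving) and `num_pools`
    2x2 max-pool layers with stride 2, each of which halves the side.
    """
    side = input_size
    for _ in range(num_pools):
        side //= 2  # MaxPool2d(kernel_size=2, stride=2) floors the halving
    return channels * side * side


# For the usual 224 x 224 input: 224 -> 112 -> 56 -> 28 -> 14 -> 7
print(vgg16_flat_features(224))  # 25088, i.e. 512 * 7 * 7
```

Under these assumptions, any other input size changes the required in_features of the first nn.Linear, which is one reason to resize images to 224 x 224 before they reach the model.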

References

  [1] torchvision VGG model source (PyTorch 0.2.0 docs): http://pytorch.org/docs/0.2.0/_modules/torchvision/models/vgg.html
  [2] pytorch/vision GitHub issue #4: https://github.com/pytorch/vision/issues/4