How To Make Your PyTorch Code Run Faster
Make your code run blazingly fast!
PyTorch is highly appreciated by researchers for its flexibility and has found its way into mainstream industries that want to stay abreast of the latest groundbreaking research.
In short, if you are a deep learning practitioner, you will come face to face with PyTorch sooner or later.
Today, I am going to cover some tricks that will greatly reduce the training time for your PyTorch models.
Data Loading
To load data for our models, we use torch.utils.data.DataLoader, which creates a Python iterable over a dataset.
Let’s take a look at its signature:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)
As we can see, there is a lot going on under the hood. Two of these arguments matter most for speeding up training: num_workers and pin_memory.
num_workers
By default, its value is 0. What does that mean? It means that the data is loaded by the main process, the same one running your training code. This is highly inefficient: whenever the main process is busy loading data, it is not training your model.
There is a better way. If we set num_workers > 0, separate worker processes handle the data loading. Loading then happens asynchronously, i.e. it does not block model training.
How much of a speed improvement will this give me?
More workers generally help, but only up to a point: beyond the number of available CPU cores, extra workers mostly add overhead. Run one epoch with a few different num_workers values and measure which one works best for your setup.
Even in the worst case, bumping num_workers from 0 to 1 should still give you roughly a 1.2x speedup.
pin_memory
Setting pin_memory=True makes the DataLoader place fetched batches in page-locked (pinned) host memory, which speeds up host-to-GPU transfers. Note: this only helps in combination with num_workers > 0; otherwise it brings no speed improvement.