PyTorch + Dask

PyTorch Distributed

This guide walks through how you can quickly distribute deep learning models using PyTorch in a distributed environment. It assumes you have already completed the basic Python & MPI tutorial, the Python & Dask tutorial, and understand the basics of torch.

High level "what's going on here"

We're going to be relying heavily on PyTorch's "torch.nn.parallel.DistributedDataParallel", using Dask to request the resources on our nodes. This strategy effectively sends different parts of our dataset - i.e., different "batches" - to different nodes and solves for them in a distributed way. A master then aggregates the results (gradients, by taking the mean), provides the solution for that epoch, updates the weights, sends the new model back out to the nodes, and repeats for each epoch.
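For orientation, here is a minimal sketch of what each worker ends up doing once it has joined the process group - the function name, model, and loader below are placeholders, not the actual objects used in the full script later in this guide, and it assumes torch.distributed has already been initialized:

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def worker_train(model, loader, epochs, lr=0.001):
    # Each worker calls this after joining the process group.
    # Wrapping the model in DDP means every backward() call averages
    # gradients across all workers, so each worker applies the same update.
    ddp_model = DDP(model)
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()
    for epoch in range(epochs):
        for inputs, labels in loader:
            optimizer.zero_grad()
            loss = loss_fn(ddp_model(inputs), labels)
            loss.backward()   # gradients averaged across all nodes here
            optimizer.step()  # identical weight update on every worker
    return ddp_model.module.state_dict()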

This strategy for parallelization is a bit different from what we present in our scikit-learn examples - i.e., here we're going to be fitting a single model using all of our nodes, instead of multiple models with different hyper-parameters.

Requirements

conda install -c pytorch pytorch=1.12.1 torchvision=0.13.1

conda install dask pandas

conda install dask-jobqueue -c conda-forge

You should have the UC Merced satellite imagery downloaded, as specified in the basics of torch tutorial.
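If you want to confirm the environment is set up before submitting a job, a quick check like the following (run from the same conda environment) can save a failed submission - the expected versions simply mirror the conda commands above:

import torch, torchvision, dask, dask_jobqueue

print("torch:", torch.__version__)                 # expect 1.12.1
print("torchvision:", torchvision.__version__)     # expect 0.13.1
print("dask:", dask.__version__)
print("dask_jobqueue:", dask_jobqueue.__version__)
print("CUDA available:", torch.cuda.is_available())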

Job File

As before with our Dask job files, there isn't much special here. We're only requesting a single node, which will submit the additional resource requests for us.

Because we will be printing to a file to generate our log, you may also want to familiarize yourself with the Linux terminal watch command.

Note that if you need to kill your job, you can first list your jobs with: qstat -u <yourusername>

Identify the jobID associated with the name specified in this job script - here it is demojob. You can then run: qdel <jobID> to kill the job. This will kill all of the Dask jobs as well (though it takes a minute for them to clear out).

Directory structure

To run successfully, the Python file will need to be modified in a few ways (a sketch of these settings follows the list):

  1. You need to point the data loader towards where you unzipped the images.

  2. You need to create a "checkpoints" folder, and point the script to where it is. This is where models are saved on disk each epoch.

  3. You need to change a few paths for file outputs, including for a log file and a file that the nodes use to exchange an IP address.
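As an illustration of the edits listed above, the top of the script might collect these locations into a few variables - the names and paths below are placeholders you would replace with your own:

from pathlib import Path

# 1. Where the unzipped UC Merced images live (placeholder path)
DATA_DIR = Path("/path/to/UCMerced_LandUse/Images")

# 2. Folder where a model checkpoint is written each epoch;
#    create it before running the job
CHECKPOINT_DIR = Path("/path/to/checkpoints")

# 3. File outputs: the training log you can follow with watch,
#    and the file the nodes use to exchange an IP address
LOG_FILE = Path("/path/to/log.txt")
IP_FILE = Path("/path/to/ip_address.txt")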

The Python File

All of the magic is done within Python, courtesy of Dask. The below is heavily commented, so I'll let those comments speak for themselves!

Note that by default it runs 5 epochs using 5 nodes - this may take a little while, so consider scaling the epoch or node count down while testing.
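As a rough map of the structure (not the full, commented script itself), the orchestration looks something like the sketch below. The cluster options, queue settings, and the train_on_worker function are placeholders for whatever your site and script actually use; the real script exchanges an IP address through a shared file, while this sketch uses PyTorch's file-based init_method for brevity, which serves the same rendezvous purpose:

from dask_jobqueue import PBSCluster
from dask.distributed import Client
import torch.distributed as dist

N_NODES = 5    # default in the script; scale down while testing
N_EPOCHS = 5

def train_on_worker(rank, world_size, shared_file):
    # Each Dask worker joins the same process group; the shared file
    # (on a filesystem visible to every node) is how ranks find each other.
    dist.init_process_group(
        backend="gloo",
        init_method=f"file://{shared_file}",
        rank=rank,
        world_size=world_size,
    )
    # ... build the dataset, a DataLoader with a DistributedSampler,
    # wrap the model in DistributedDataParallel, and run N_EPOCHS ...
    dist.destroy_process_group()
    return rank

if __name__ == "__main__":
    # One Dask worker per node; cores/memory are placeholders for your site.
    cluster = PBSCluster(cores=1, memory="8GB")
    cluster.scale(jobs=N_NODES)
    client = Client(cluster)

    futures = [
        client.submit(train_on_worker, rank, N_NODES, "/path/to/sharedfile")
        for rank in range(N_NODES)
    ]
    print(client.gather(futures))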
