Skip to content

Commit

Permalink
Enable running the mnist_training sample without cuda (#6085)
Browse files Browse the repository at this point in the history
Signed-off-by: George Nash <[email protected]>
  • Loading branch information
georgen117 authored Dec 16, 2020
1 parent ac62cf8 commit 939cc9b
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions orttraining/pytorch_frontend_examples/mnist_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer
from mpi4py import MPI
from onnxruntime.capi._pybind_state import set_cuda_device_id
try:
from onnxruntime.capi._pybind_state import set_cuda_device_id
except ImportError:
pass

class NeuralNet(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
Expand Down Expand Up @@ -116,13 +119,13 @@ def main():
args.local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) if ('OMPI_COMM_WORLD_LOCAL_RANK' in os.environ) else 0
args.world_rank = int(os.environ['OMPI_COMM_WORLD_RANK']) if ('OMPI_COMM_WORLD_RANK' in os.environ) else 0
args.world_size=comm.Get_size()
torch.cuda.set_device(args.local_rank)
if use_cuda:
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
args.n_gpu = 1
set_cuda_device_id(args.local_rank)
else:
device = torch.device("cpu")
args.n_gpu = 1
set_cuda_device_id(args.local_rank)

input_size = 784
hidden_size = 500
Expand Down

0 comments on commit 939cc9b

Please sign in to comment.