[3/N] Refine beginner tutorial by accelerator api #3170

Open · wants to merge 2 commits into base: main
12 changes: 7 additions & 5 deletions beginner_source/chatbot_tutorial.py
@@ -108,8 +108,10 @@
 import json


-USE_CUDA = torch.cuda.is_available()
-device = torch.device("cuda" if USE_CUDA else "cpu")
+# If the current `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__ is available,
+# we will use it. Otherwise, we use the CPU.
+device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
+print(f"Using {device} device")


 ######################################################################
@@ -1318,16 +1320,16 @@ def evaluateInput(encoder, decoder, searcher, voc):
 encoder_optimizer.load_state_dict(encoder_optimizer_sd)
 decoder_optimizer.load_state_dict(decoder_optimizer_sd)

-# If you have CUDA, configure CUDA to call
+# If you have an accelerator, configure it to call
 for state in encoder_optimizer.state.values():
     for k, v in state.items():
         if isinstance(v, torch.Tensor):
-            state[k] = v.cuda()
+            state[k] = v.to(device)

 for state in decoder_optimizer.state.values():
     for k, v in state.items():
         if isinstance(v, torch.Tensor):
-            state[k] = v.cuda()
+            state[k] = v.to(device)

 # Run training iterations
 print("Starting Training!")
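Note on the hunk above: a minimal, self-contained sketch of the pattern it applies, namely selecting the device through the accelerator API and moving a loaded optimizer's state onto it. It assumes PyTorch 2.6+ (where `torch.accelerator` exists); the tiny model and optimizer are illustrative stand-ins, not the tutorial's.

```python
import torch

# Pick the device via the accelerator API (PyTorch 2.6+); fall back to CPU.
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"

model = torch.nn.Linear(4, 2).to(device)  # illustrative stand-in model
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Take one step so the optimizer actually holds momentum buffers in its state.
model(torch.randn(8, 4, device=device)).sum().backward()
opt.step()

# After load_state_dict() from a checkpoint, state tensors may live on the
# wrong device (e.g. restored on CPU); move each one explicitly, as the
# hunk above does with v.to(device).
for state in opt.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)
```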
48 changes: 22 additions & 26 deletions beginner_source/introyt/tensors_deeper_tutorial.py
@@ -632,34 +632,33 @@
 # does this *without* changing ``a`` - you can see that when we print
 # ``a`` again at the end, it retains its ``requires_grad=True`` property.
 #
-# Moving to GPU
+# Moving to `Accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
 # -------------
 #
-# One of the major advantages of PyTorch is its robust acceleration on
-# CUDA-compatible Nvidia GPUs. (“CUDA” stands for *Compute Unified Device
-# Architecture*, which is Nvidia’s platform for parallel computing.) So
-# far, everything we’ve done has been on CPU. How do we move to the faster
+# One of the major advantages of PyTorch is its robust acceleration on an
+# `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
+# such as CUDA, MPS, MTIA, or XPU.
+# So far, everything we’ve done has been on CPU. How do we move to the faster
 # hardware?
 #
-# First, we should check whether a GPU is available, with the
+# First, we should check whether an accelerator is available, with the
 # ``is_available()`` method.
 #
 # .. note::
-#      If you do not have a CUDA-compatible GPU and CUDA drivers
-#      installed, the executable cells in this section will not execute any
-#      GPU-related code.
+#      If you do not have an accelerator, the executable cells in this section will not execute any
+#      accelerator-related code.
 #

-if torch.cuda.is_available():
-    print('We have a GPU!')
+if torch.accelerator.is_available():
+    print('We have an accelerator!')
 else:
     print('Sorry, CPU only.')


 ##########################################################################
-# Once we’ve determined that one or more GPUs is available, we need to put
-# our data someplace where the GPU can see it. Your CPU does computation
-# on data in your computer’s RAM. Your GPU has dedicated memory attached
+# Once we’ve determined that one or more accelerators is available, we need to put
+# our data someplace where the accelerator can see it. Your CPU does computation
+# on data in your computer’s RAM. Your accelerator has dedicated memory attached
 # to it. Whenever you want to perform a computation on a device, you must
 # move *all* the data needed for that computation to memory accessible by
 # that device. (Colloquially, “moving the data to memory accessible by the
@@ -669,34 +668,31 @@
 # may do it at creation time:
 #

-if torch.cuda.is_available():
-    gpu_rand = torch.rand(2, 2, device='cuda')
+if torch.accelerator.is_available():
+    gpu_rand = torch.rand(2, 2, device=torch.accelerator.current_accelerator())
     print(gpu_rand)
 else:
     print('Sorry, CPU only.')


 ##########################################################################
 # By default, new tensors are created on the CPU, so we have to specify
-# when we want to create our tensor on the GPU with the optional
+# when we want to create our tensor on the accelerator with the optional
 # ``device`` argument. You can see when we print the new tensor, PyTorch
 # informs us which device it’s on (if it’s not on CPU).
 #
-# You can query the number of GPUs with ``torch.cuda.device_count()``. If
-# you have more than one GPU, you can specify them by index:
+# You can query the number of accelerators with ``torch.accelerator.device_count()``. If
+# you have more than one accelerator, you can specify them by index, take CUDA for example:
 # ``device='cuda:0'``, ``device='cuda:1'``, etc.
 #
 # As a coding practice, specifying our devices everywhere with string
 # constants is pretty fragile. In an ideal world, your code would perform
-# robustly whether you’re on CPU or GPU hardware. You can do this by
+# robustly whether you’re on CPU or accelerator hardware. You can do this by
 # creating a device handle that can be passed to your tensors instead of a
 # string:
 #

-if torch.cuda.is_available():
-    my_device = torch.device('cuda')
-else:
-    my_device = torch.device('cpu')
+my_device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else torch.device('cpu')
 print('Device: {}'.format(my_device))

 x = torch.rand(2, 2, device=my_device)
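Note on the hunk above: `torch.accelerator.device_count()` and the per-index device strings it mentions compose as sketched below (assuming PyTorch 2.6+; CUDA-style indices shown purely for illustration).

```python
import torch

# Enumerate accelerators by index, building device handles from the type string.
if torch.accelerator.is_available():
    acc_type = torch.accelerator.current_accelerator().type  # e.g. "cuda"
    devices = [torch.device(f"{acc_type}:{i}")
               for i in range(torch.accelerator.device_count())]
    print(devices)  # e.g. [device(type='cuda', index=0), ...]
else:
    print('Sorry, CPU only.')
```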
@@ -718,12 +714,12 @@
 # It is important to know that in order to do computation involving two or
 # more tensors, *all of the tensors must be on the same device*. The
 # following code will throw a runtime error, regardless of whether you
-# have a GPU device available:
+# have an accelerator device available, take CUDA for example:
 #
 # .. code-block:: python
 #
 #     x = torch.rand(2, 2)
-#     y = torch.rand(2, 2, device='gpu')
+#     y = torch.rand(2, 2, device='cuda')
 #     z = x + y  # exception will be thrown
 #

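Note on the last hunk: the same-device rule can be demonstrated end to end. A short sketch, assuming PyTorch 2.6+; on a CPU-only machine the guarded addition simply succeeds, since both tensors end up on the CPU.

```python
import torch

my_device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else torch.device('cpu')

x = torch.rand(2, 2, device=my_device)  # created directly on the device
y = torch.rand(2, 2)                    # created on the CPU by default

try:
    z = x + y  # raises a RuntimeError whenever my_device is not the CPU
except RuntimeError as err:
    print(err)

z = x + y.to(my_device)  # fix: move y onto the same device first
print(z.device)
```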
6 changes: 4 additions & 2 deletions beginner_source/knowledge_distillation_tutorial.py
@@ -37,8 +37,10 @@
 import torchvision.transforms as transforms
 import torchvision.datasets as datasets

-# Check if GPU is available, and if not, use the CPU
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# Check if the current `accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
+# is available, and if not, use the CPU
+device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
+print(f"Using {device} device")

 ######################################################################
 # Loading CIFAR-10
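Note on the hunk above: `torch.accelerator` only exists in recent PyTorch releases, so readers on older builds would hit an `AttributeError` on the new line. A hypothetical guard (not part of this PR) that falls back to the old CUDA check:

```python
import torch

# Hypothetical compatibility guard for PyTorch builds that predate
# the torch.accelerator module (an assumption, not part of this diff).
if hasattr(torch, "accelerator") and torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator().type
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
```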
25 changes: 11 additions & 14 deletions beginner_source/nn_tutorial.py
@@ -132,7 +132,7 @@
 # we'll write `log_softmax` and use it. Remember: although PyTorch
 # provides lots of prewritten loss functions, activation functions, and
 # so forth, you can easily write your own using plain python. PyTorch will
-# even create fast GPU or vectorized CPU code for your function
+# even create fast accelerator or vectorized CPU code for your function
 # automatically.

 def log_softmax(x):
@@ -827,38 +827,35 @@ def __iter__(self):
 fit(epochs, model, loss_func, opt, train_dl, valid_dl)

 ###############################################################################
-# Using your GPU
+# Using your `Accelerator <https://pytorch.org/docs/stable/torch.html#accelerators>`__
 # ---------------
 #
-# If you're lucky enough to have access to a CUDA-capable GPU (you can
+# If you're lucky enough to have access to an accelerator such as CUDA (you can
 # rent one for about $0.50/hour from most cloud providers) you can
-# use it to speed up your code. First check that your GPU is working in
+# use it to speed up your code. First check that your accelerator is working in
 # Pytorch:

-print(torch.cuda.is_available())
+# If the current accelerator is available, we will use it. Otherwise, we use the CPU.
+device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
+print(f"Using {device} device")

-###############################################################################
-# And then create a device object for it:
-
-dev = torch.device(
-    "cuda") if torch.cuda.is_available() else torch.device("cpu")

 ###############################################################################
-# Let's update ``preprocess`` to move batches to the GPU:
+# Let's update ``preprocess`` to move batches to the accelerator:


 def preprocess(x, y):
-    return x.view(-1, 1, 28, 28).to(dev), y.to(dev)
+    return x.view(-1, 1, 28, 28).to(device), y.to(device)


 train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
 train_dl = WrappedDataLoader(train_dl, preprocess)
 valid_dl = WrappedDataLoader(valid_dl, preprocess)

 ###############################################################################
-# Finally, we can move our model to the GPU.
+# Finally, we can move our model to the accelerator.

-model.to(dev)
+model.to(device)
 opt = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

 ###############################################################################
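Note on the hunk above: the updated `preprocess` relies on `WrappedDataLoader`, which is defined earlier in nn_tutorial.py and only visible here as diff context. A self-contained sketch of the batch-moving pattern, with a minimal stand-in for that class:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"

# Minimal stand-in for the tutorial's WrappedDataLoader: applies a
# function to every batch the wrapped DataLoader yields.
class WrappedDataLoader:
    def __init__(self, dl, func):
        self.dl, self.func = dl, func

    def __iter__(self):
        for batch in self.dl:
            yield self.func(*batch)

def preprocess(x, y):
    # Reshape flat MNIST-style rows to NCHW and move both tensors to the device.
    return x.view(-1, 1, 28, 28).to(device), y.to(device)

ds = TensorDataset(torch.rand(64, 784), torch.randint(0, 10, (64,)))
dl = WrappedDataLoader(DataLoader(ds, batch_size=16), preprocess)
xb, yb = next(iter(dl))
print(xb.device, xb.shape)  # every batch now comes off the loader on `device`
```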