
[doc] add small example to flight recorder tutorial (#3163)
* [doc] add small example to flight recorder tutorial
---------

Co-authored-by: Svetlana Karslioglu <[email protected]>
c-p-i-o and svekars authored Nov 26, 2024
1 parent 540bd0c commit 3ba3a46
Showing 1 changed file with 94 additions and 0 deletions.
94 changes: 94 additions & 0 deletions prototype_source/flight_recorder_tutorial.rst
@@ -202,6 +202,100 @@ Caveat: the ``tabulate`` module is needed, so you might need to ``pip install`` it first.
python fr_trace.py <dump dir containing trace files> -j [--selected-ranks i j k ...] [--pg-filters tp dp]
torchfrtrace <dump dir containing trace files> -j [--selected-ranks i j k ...] [--pg-filters 0 2]

An End-to-End Example
---------------------

To demonstrate the use of Flight Recorder, we will use a small program that induces mismatched collectives: ``rank0`` is programmed to issue one additional collective that the other ranks do not.
The Flight Recorder dump files are saved to the ``/tmp`` directory.
For demonstration purposes, we named this program ``crash.py``.

.. note::
   This is a simplified example. Real-world scenarios typically involve additional complexity.

.. code:: python
   :caption: A crashing example

   import torch
   import torch.distributed as dist
   import os
   from datetime import timedelta

   local_rank = int(os.environ["LOCAL_RANK"])
   world_size = int(os.environ["WORLD_SIZE"])
   assert world_size <= 8, "world size must be less than or equal to 8"

   # Enable Flight Recorder: dump traces to /tmp/trace_<rank> on timeout and keep up to 2000 entries
   os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = "/tmp/trace_"
   os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"
   os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "2000"

   device = torch.device(f"cuda:{local_rank}")
   print(f"{local_rank=} {world_size=} master addr: {os.environ['MASTER_ADDR']} master port: {os.environ['MASTER_PORT']} {device=}")

   # Initialize the process group with a small timeout so that jobs fail quickly
   dist.init_process_group("nccl", world_size=world_size, rank=local_rank, timeout=timedelta(seconds=1))

   a = torch.full((3, 4), float(local_rank), device=device)

   # Issue some collectives to populate the Flight Recorder buffer
   for i in range(2):
       print(f"calling allreduce on {local_rank=}")
       f = dist.all_reduce(a)

   # rank0 issues an additional collective that the other ranks do not
   if local_rank == 0:
       print("rank0 is doing an allreduce on tensor b, but other ranks forgot")
       b = torch.full((4, 5), float(local_rank), device=device)
       f = dist.all_reduce(b)

   for i in range(2):
       print(f"calling allreduce on {local_rank=}")
       f = dist.all_reduce(a)

   torch.cuda.synchronize(device=device)
   print(f"{local_rank=} exiting")

To run this program, use ``torchrun``:


.. code:: bash

   torchrun --nnodes=1 --nproc_per_node=2 crash.py

You should see two files in the ``/tmp`` directory:

.. code:: bash

   $ ls /tmp/trace*
   # Expected output
   /tmp/trace_0 /tmp/trace_1

Finally, to analyze these two files, we use the ``torchfrtrace`` command:

.. code:: bash

   torchfrtrace --prefix "trace_" /tmp/

The output of the trace command is meant to be human-readable. It includes information about the set of
collectives that caused the failure. The output for the command above is shown below: we can clearly see
that rank 1 did not join the ``all_reduce`` collective.

.. code-block:: bash

   $ torchfrtrace --prefix "trace_" /tmp/
   Not all ranks joining collective 5 at entry 4
   group info: 0:default_pg
   collective: nccl:all_reduce
   missing ranks: {1}
   input sizes: [[3, 4]]
   output sizes: [[3, 4]]
   expected ranks: 2
   collective state: scheduled
   collective stack trace:
        all_reduce at /home/cpio/local/pytorch/torch/distributed/distributed_c10d.py:2696
        wrapper at /home/cpio/local/pytorch/torch/distributed/c10d_logger.py:83
        <module> at /home/cpio/test/crash.py:44
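
Besides the command-line analyzer, the raw dump files can also be inspected directly. The snippet below is a
minimal sketch, assuming the dumps are pickled Python dictionaries that contain an ``entries`` list (the format
the ``torchfrtrace`` analyzer consumes); exact key names such as ``profiling_name`` and ``state`` may vary
across PyTorch versions.

.. code:: python

   # Illustrative sketch only: peek inside a single Flight Recorder dump.
   # Assumption: the dump is a pickled dictionary with an "entries" list;
   # key names such as "profiling_name" and "state" may differ across
   # PyTorch versions.
   import pickle

   with open("/tmp/trace_0", "rb") as f:
       dump = pickle.load(f)

   print(dump.keys())
   for entry in dump.get("entries", []):
       # Each entry describes one collective recorded on this rank
       print(entry.get("profiling_name"), entry.get("state"))

Comparing the entries recorded by rank 0 and rank 1 should reveal the extra ``all_reduce`` issued by rank 0.
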
Conclusion
----------
In this tutorial, we have learned about a new PyTorch diagnostic tool called Flight Recorder.
