Skip to content

Commit

Permalink
Merge pull request #132 from rasbt/update-torchmetrics
Browse files Browse the repository at this point in the history
update for newer versions of torchmetrics
  • Loading branch information
rasbt authored May 23, 2023
2 parents 322f57b + 79be110 commit 92e2320
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 8 deletions.
22 changes: 21 additions & 1 deletion ERRATA/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,32 @@ should be

$$\frac{\partial L}{\partial w_{j, l}^{(l)}}$$

## Chapter 18
## Chapter 12

**Page 380**

We use `TensorDataset` even though we defined the custom `JointDataset`

## Chapter 13

**Page 431**

When using Torchmetrics 0.8.0 or newer, the following lines

```python
self.train_acc = Accuracy()
self.valid_acc = Accuracy()
self.test_acc = Accuracy()
```

need to be changed to

```python
self.train_acc = Accuracy(task="multiclass", num_classes=10)
self.valid_acc = Accuracy(task="multiclass", num_classes=10)
self.test_acc = Accuracy(task="multiclass", num_classes=10)
```

## Chapter 15


Expand Down
17 changes: 13 additions & 4 deletions ch13/ch13_part3_lightning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@
"import torch \n",
"import torch.nn as nn \n",
"\n",
"from torchmetrics import __version__ as torchmetrics_version\n",
"from pkg_resources import parse_version\n",
"\n",
"from torchmetrics import Accuracy"
]
},
Expand All @@ -163,9 +166,15 @@
" super().__init__()\n",
" \n",
" # new PL attributes:\n",
" self.train_acc = Accuracy()\n",
" self.valid_acc = Accuracy()\n",
" self.test_acc = Accuracy()\n",
" \n",
" if parse_version(torchmetrics_version) > parse_version(0.8):\n",
" self.train_acc = Accuracy(task=\"multiclass\", num_classes=10)\n",
" self.valid_acc = Accuracy(task=\"multiclass\", num_classes=10)\n",
" self.test_acc = Accuracy(task=\"multiclass\", num_classes=10)\n",
" else:\n",
" self.train_acc = Accuracy()\n",
" self.valid_acc = Accuracy()\n",
" self.test_acc = Accuracy()\n",
" \n",
" # Model similar to previous section:\n",
" input_size = image_shape[0] * image_shape[1] * image_shape[2] \n",
Expand Down Expand Up @@ -1092,7 +1101,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
"version": "3.10.10"
}
},
"nbformat": 4,
Expand Down
14 changes: 11 additions & 3 deletions ch13/ch13_part3_lightning.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# coding: utf-8


from pkg_resources import parse_version
import sys
from python_environment_check import check_packages
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torchmetrics import __version__ as torchmetrics_version
from torchmetrics import Accuracy
from torch.utils.data import DataLoader
from torch.utils.data import random_split
Expand Down Expand Up @@ -72,9 +74,15 @@ def __init__(self, image_shape=(1, 28, 28), hidden_units=(32, 16)):
super().__init__()

# new PL attributes:
self.train_acc = Accuracy()
self.valid_acc = Accuracy()
self.test_acc = Accuracy()

if parse_version(torchmetrics_version) > parse_version(0.8):
self.train_acc = Accuracy(task="multiclass", num_classes=10)
self.valid_acc = Accuracy(task="multiclass", num_classes=10)
self.test_acc = Accuracy(task="multiclass", num_classes=10)
else:
self.train_acc = Accuracy()
self.valid_acc = Accuracy()
self.test_acc = Accuracy()

# Model similar to previous section:
input_size = image_shape[0] * image_shape[1] * image_shape[2]
Expand Down

0 comments on commit 92e2320

Please sign in to comment.