Skip to content

Commit

Permalink
review #2 (#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
tcoroller authored Nov 8, 2024
1 parent 89781b0 commit 63c2080
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 189 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
<p align="center">
<img src="https://github.com/Novartis/torchsurv/blob/main/docs/source/logo_firecamp.png" width="300">
<!-- <img src="https://github.com/Novartis/torchsurv/blob/main/docs/source/logo_firecamp.png" width="300"> -->
<img src="./docs/source/logo_firecamp.png" width="300">

</p>

# Deep survival analysis made easy
Expand Down Expand Up @@ -55,13 +57,11 @@ cindex.p_value(method="noether", alternative="two_sided")
cindex.compare(cindexB)
```


## Installation and dependencies


First, install the package using either [PyPI]([https://pypi.org/](https://pypi.org/project/torchsurv/)) or [Conda]([https://anaconda.org/anaconda/conda](https://anaconda.org/conda-forge/torchsurv))

- Using conda (`recommended`)
- Using conda (**recommended**)
```bash
conda install conda-forge::torchsurv
```
Expand Down
117 changes: 66 additions & 51 deletions docs/notebooks/introduction.ipynb

Large diffs are not rendered by default.

178 changes: 44 additions & 134 deletions docs/notebooks/momentum.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"\n",
"### Dependencies\n",
"\n",
"To run this notebooks, dependencies must be installed. the recommended method is to use our development conda environment (`preferred`). Instruction can be found [here](https://opensource.nibr.com/torchsurv/devnotes.html#set-up-a-development-environment-via-conda) to install all optional dependencies. The other method is to install only required packages using the command line below:\n"
"To run this notebooks, dependencies must be installed. the recommended method is to use our development conda environment (**preferred**). Instruction can be found [here](https://opensource.nibr.com/torchsurv/devnotes.html#set-up-a-development-environment-via-conda) to install all optional dependencies. The other method is to install only required packages using the command line below:\n"
]
},
{
Expand Down Expand Up @@ -104,11 +104,34 @@
{
"cell_type": "code",
"execution_count": 6,
"id": "ebaf967b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CUDA-enabled GPU/TPU is available.\n"
]
}
],
"source": [
"# Detect available accelerator; Downgrade batch size if only CPU available\n",
"if any([torch.cuda.is_available(), torch.backends.mps.is_available()]):\n",
" print(\"CUDA-enabled GPU/TPU is available.\")\n",
" BATCH_SIZE = 500 # batch size for training\n",
"else:\n",
" print(\"No CUDA-enabled GPU found, using CPU.\")\n",
" BATCH_SIZE = 50 # batch size for training"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "794004c5-588c-4590-ae96-c6d9e52109ff",
"metadata": {},
"outputs": [],
"source": [
"BATCH_SIZE = 500 # batch size for training\n",
"EPOCHS = 2 # number of epochs to train\n",
"FAST_DEV_RUN = None # Quick prototype, set to None for full training"
]
Expand All @@ -135,7 +158,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "4abbc6b0",
"metadata": {},
"outputs": [],
Expand All @@ -153,7 +176,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"id": "ebf5caff",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -201,7 +224,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"id": "c216fa33-de09-4be2-82cc-83cb73db3a42",
"metadata": {},
"outputs": [],
Expand All @@ -217,7 +240,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"id": "8056a675-fbce-4f4b-86c0-ab7dd924e4b1",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -255,7 +278,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"id": "1e7a2c7e-a1ef-42fa-ba74-1d33a1dcf2f3",
"metadata": {},
"outputs": [],
Expand All @@ -266,7 +289,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"id": "3f577acf-a821-41a4-8544-318617755d1e",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -296,7 +319,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"id": "430079cc-4fad-4da2-8ea5-aa904c41ec0e",
"metadata": {},
"outputs": [
Expand All @@ -319,21 +342,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1: 100%|██████████| 11/11 [01:02<00:00, 0.18it/s, loss_step=218.0, val_loss_step=282.0, cindex_step=0.652, val_loss_epoch=287.0, cindex_epoch=0.665, loss_epoch=228.0]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_epochs=2` reached.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1: 100%|██████████| 11/11 [01:02<00:00, 0.18it/s, loss_step=218.0, val_loss_step=282.0, cindex_step=0.652, val_loss_epoch=287.0, cindex_epoch=0.665, loss_epoch=228.0]\n"
"Epoch 0: 0%| | 0/11 [00:00<?, ?it/s] "
]
}
],
Expand All @@ -344,34 +353,10 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"id": "7854deb3-52f8-4a92-b38f-ff304bf82a34",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Testing DataLoader 0: 100%|██████████| 20/20 [01:22<00:00, 0.24it/s]\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
" Test metric DataLoader 0\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
" cindex_epoch 0.6686268448829651\n",
" val_loss_epoch -458.2862548828125\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n"
]
},
{
"data": {
"text/plain": [
"[{'val_loss_epoch': -458.2862548828125, 'cindex_epoch': 0.6686268448829651}]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# Test the model\n",
"trainer.test(model_regular, datamodule)"
Expand All @@ -391,7 +376,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": null,
"id": "bab50a3f-5670-4264-b2c0-4eccb5f48624",
"metadata": {},
"outputs": [],
Expand All @@ -408,51 +393,10 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": null,
"id": "00473ec0-9f44-47f2-824d-02dcc92dba7d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (mps), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
"\n",
" | Name | Type | Params\n",
"-----------------------------------\n",
"0 | model | Momentum | 22.3 M\n",
"-----------------------------------\n",
"11.2 M Trainable params\n",
"11.2 M Non-trainable params\n",
"22.3 M Total params\n",
"89.366 Total estimated model params size (MB)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1: 100%|██████████| 110/110 [01:18<00:00, 1.40it/s, loss_step=57.10, val_loss_step=63.80, cindex_step=0.848, val_loss_epoch=59.80, cindex_epoch=0.841, loss_epoch=58.10]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_epochs=2` reached.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1: 100%|██████████| 110/110 [01:18<00:00, 1.40it/s, loss_step=57.10, val_loss_step=63.80, cindex_step=0.848, val_loss_epoch=59.80, cindex_epoch=0.841, loss_epoch=58.10]\n"
]
}
],
"outputs": [],
"source": [
"# Define trainer\n",
"trainer = L.Trainer(\n",
Expand All @@ -470,34 +414,10 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": null,
"id": "6441c1ea-b87f-4ff7-92dd-a8d7abf8daa5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Testing DataLoader 0: 100%|██████████| 200/200 [01:38<00:00, 2.03it/s]\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
" Test metric DataLoader 0\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
" cindex_epoch 0.858147144317627\n",
" val_loss_epoch 72.23859405517578\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n"
]
},
{
"data": {
"text/plain": [
"[{'val_loss_epoch': 72.23859405517578, 'cindex_epoch': 0.858147144317627}]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# Validate the model\n",
"trainer.test(model_momentum, datamodule_momentum)"
Expand All @@ -513,7 +433,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": null,
"id": "855bda61",
"metadata": {},
"outputs": [],
Expand All @@ -527,7 +447,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": null,
"id": "38b1f7d1",
"metadata": {},
"outputs": [],
Expand All @@ -553,20 +473,10 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": null,
"id": "112e2e5d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cindex (regular) = 0.6948477029800415\n",
"Cindex (momentum) = 0.8578558564186096\n",
"Compare (p-value) = 2.1650459203215178e-11\n"
]
}
],
"outputs": [],
"source": [
"print(f\"Cindex (regular) = {cindex1(log_hz1, torch.ones_like(y).bool(), y.float())}\")\n",
"print(f\"Cindex (momentum) = {cindex2(log_hz2, torch.ones_like(y).bool(), y.float())}\")\n",
Expand Down

0 comments on commit 63c2080

Please sign in to comment.