Skip to content

Commit

Permalink
updating branch (#55)
Browse files Browse the repository at this point in the history
* template paper

* template paper

* Revert "template paper"

This reverts commit 95d70f3.

* Revert "template paper"

This reverts commit dc59caa.

* safer eval model

* reverted change

* added badge conda (#47)

* Added conda information

* clamped log_hazard to prevent torch.Inf

* try fix CI builds (#53)

* try fix builds

* partial rollback

* fix build dependency

* make formatting mandatory & reformat kaplan meier test

* add unreleased note

* edited

* adding sphinx-issues for Changelog links

* reverted change

---------

Co-authored-by: corolth1 <[email protected]>

---------

Co-authored-by: melodiemonod <[email protected]>
Co-authored-by: Peter Krusche (Novartis) <[email protected]>
  • Loading branch information
3 people authored Sep 26, 2024
1 parent ba6ffb5 commit 8ac3811
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build-release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
run: python -m twine check dist/*

- name: Upload artifact
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v4
with:
name: Python-package
path: dist
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/codeqc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ jobs:
shell: bash -l {0}
run: |
conda activate torchsurv
./dev/codeqc.sh
./dev/codeqc.sh check
- name: Tests
shell: bash -l {0}
run: |
Expand Down
25 changes: 18 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

![CodeQC](https://github.com/Novartis/torchsurv/actions/workflows/codeqc.yml/badge.svg?branch=main)
![Docs](https://github.com/Novartis/torchsurv/actions/workflows/docs.yml/badge.svg?branch=main)
[![PyPI - Version](https://img.shields.io/pypi/v/torchsurv)](https://pypi.org/project/torchsurv/)
[![arXiv](https://img.shields.io/badge/arXiv-2404.10761-f9f107.svg)](https://arxiv.org/abs/2404.10761)
[![status](https://camo.githubusercontent.com/22fa65b2a659780cddfac609463c5fe719e3ea82a28eb7a61e24b7c4e40eb56d/68747470733a2f2f6a6f73732e7468656f6a2e6f72672f7061706572732f30326437343936646132623963633334663961366530346361626632323938642f7374617475732e737667)](https://joss.theoj.org/papers/02d7496da2b9cc34f9a6e04cabf2298d)
[![PyPI - Version](https://img.shields.io/pypi/v/torchsurv?)](https://pypi.org/project/torchsurv/)
[![Conda](https://img.shields.io/conda/v/conda-forge/torchsurv?label=conda)](https://anaconda.org/conda-forge/torchsurv)
[![arXiv](https://img.shields.io/badge/arXiv-2404.10761-f9f107.svg?)](https://arxiv.org/abs/2404.10761)
[![Documentation](https://img.shields.io/badge/GithubPage-Sphinx-blue)](https://opensource.nibr.com/torchsurv/)
[![Downloads](https://static.pepy.tech/badge/torchsurv)](https://pepy.tech/project/torchsurv)
[![PyPI Downloads](https://img.shields.io/pypi/dm/torchsurv.svg?label=PyPI%20downloads)](
https://pypi.org/project/torchsurv/)
[![Conda Downloads](https://img.shields.io/conda/dn/conda-forge/torchsurv.svg?label=Conda%20downloads)](
https://anaconda.org/conda-forge/torchsurv)

`TorchSurv` is a Python package that serves as a companion tool to perform deep survival modeling within the `PyTorch` environment. Unlike existing libraries that impose specific parametric forms on users, `TorchSurv` enables the use of custom `PyTorch`-based deep survival models. With its lightweight design, minimal input requirements, full `PyTorch` backend, and freedom from restrictive survival model parameterizations, `TorchSurv` facilitates efficient survival model implementation, particularly beneficial for high-dimensional input data scenarios.

Expand Down Expand Up @@ -43,15 +46,23 @@ cindex.compare(cindexB)

## Installation and dependencies

First, install the package:

First, install the package using either [PyPI]([https://pypi.org/](https://pypi.org/project/torchsurv/)) or [Conda]([https://anaconda.org/anaconda/conda](https://anaconda.org/conda-forge/torchsurv))

- Using conda (`recommended`)
```bash
conda install conda-forge::torchsurv
```
- Using PyPI
```bash
pip install torchsurv
```

or for local installation (from package root / clone of this git repository):
- Using for local installation (`latest version`)

```bash
git clone <repo>
cd <repo>
pip install -e .
```

Expand Down Expand Up @@ -237,4 +248,4 @@ If you use this project in academic work or publications, we appreciate citing i
primaryClass={cs.LG},
doi={https://doi.org/10.48550/arXiv.2404.10761}
}
```
```
2 changes: 1 addition & 1 deletion dev/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ channels:
- conda-forge
- pytorch
dependencies:
- build=0.7.0
- python-build=1.2.2
- pep517=0.13.0
- numpy=1.26.4
- pandas=2.2.0
Expand Down
10 changes: 10 additions & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
Change log
=========

Version 0.1.3 (unreleased)
--------------------------

* Tutorial dataset error on momentum.ipynb #50
* Fix issue #48 - log_hazard returns torch.Inf
* Fix warning with Spearman correlation #41
* Added in-depth statistical background to link AUC to C-index #39
* Created Conda Forge version #47
* Updated CICD builds #53

Version 0.1.2
-------------

Expand Down
4 changes: 2 additions & 2 deletions src/torchsurv/loss/momentum.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def forward(
self.memory_k.append(self.survtuple(*list(estimate)))
return loss

@torch.no_grad()
@torch.no_grad() # deactivates autograd
def infer(self, inputs: torch.Tensor) -> torch.Tensor:
"""Evaluate data with target network
Expand All @@ -183,7 +183,7 @@ def infer(self, inputs: torch.Tensor) -> torch.Tensor:
[ 0.9771, -0.8513]])
"""
self.target.eval() # Disable training tricks (augmentation, dropout, etc..)
self.target.eval() # notify all your layers that you are in eval mode
return self.target(inputs)

def _bank_loss(self) -> torch.Tensor:
Expand Down
9 changes: 7 additions & 2 deletions src/torchsurv/loss/weibull.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,9 @@ def log_hazard(
>>> for t in torch.tensor([100.0, 150.0]): log_hazard(log_params, time=t) # Subject-specific log hazard at multiple new times
tensor([ 1.1280, -0.0372, -3.9767, 1.0757])
tensor([ 1.2330, -0.1062, -4.1680, 1.1999])
>>> log_params *= 1e2 # Increase scale
>>> log_hazard(log_params, time, all_times = False) # Check for Torch.Inf values
tensor([-1.0000e+10, -2.3197e+01, -6.8385e+01, -1.0000e+10])
"""

log_scale, log_shape = _check_log_shape(log_params).unbind(1)
Expand All @@ -247,11 +250,13 @@ def log_hazard(
f"Dimension mismatch: 'time' ({len(time)}) does not match the length of 'log_params' ({len(log_params)})."
)

return (
return torch.clamp(
log_shape
- log_scale
+ torch.expm1(log_shape)
* (torch.log(torch.clip(time, 1e-100, torch.inf)) - log_scale)
* (torch.log(torch.clip(time, 1e-100, torch.inf)) - log_scale),
min=-TORCH_CLAMP_VALUE,
max=TORCH_CLAMP_VALUE,
)


Expand Down
6 changes: 3 additions & 3 deletions tests/test_kaplan_meier.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,9 @@ def test_kaplan_meier_prediction_error_raised(self):
for batch in batch_container.batches:
(train_time, train_event, test_time, *_) = batch

train_event[
-1
] = False # if last event is censoring, the last KM is > 0 and it cannot predict beyond this time
train_event[-1] = (
False # if last event is censoring, the last KM is > 0 and it cannot predict beyond this time
)
km = KaplanMeierEstimator()
km(train_event, train_time, censoring_dist=False)

Expand Down

0 comments on commit 8ac3811

Please sign in to comment.