diff --git a/.github/workflows/python-package-conda.yml b/.github/workflows/python-package-conda.yml index f38ddc6..a8bfe6a 100644 --- a/.github/workflows/python-package-conda.yml +++ b/.github/workflows/python-package-conda.yml @@ -2,13 +2,16 @@ name: Python Package using Conda on: [push] +env: + CODECOV_TOKEN: 5454ef86-3f2b-45a7-8df0-636d3044ae13 + jobs: build-linux: runs-on: ubuntu-latest strategy: fail-fast: false matrix: - python-version: ['3.8', '3.9', '3.10'] + python-version: ['3.8', '3.9'] steps: - uses: actions/checkout@v2 @@ -38,8 +41,9 @@ jobs: run: | conda install pytest pip install pytest-html + pip install pytest-cov pip install -e . - pytest test --html=${{ matrix.python-version }}-results.html --self-contained-html + pytest test --html=${{ matrix.python-version }}-results.html --self-contained-html --cov=./ --cov-report=xml - name: Move artifacts shell: bash -l {0} run: mv test.log ${{ matrix.python-version }}-test.log @@ -50,4 +54,7 @@ jobs: name: ${{ matrix.python-version }}-artifacts path: | ${{ matrix.python-version }}-results.html - ${{ matrix.python-version }}-test.log \ No newline at end of file + ${{ matrix.python-version }}-test.log + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v2 + if: always() diff --git a/README.md b/README.md index db999e0..e186171 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,61 @@ +# trough [![Python Package using Conda](https://github.com/gregstarr/trough/actions/workflows/python-package-conda.yml/badge.svg)](https://github.com/gregstarr/trough/actions/workflows/python-package-conda.yml) +[![codecov](https://codecov.io/gh/gregstarr/trough/branch/master/graph/badge.svg?token=QNCESQ41EW)](https://codecov.io/gh/gregstarr/trough) -# trough -mid latitude ionospheric trough research +![GitHub](https://img.shields.io/github/license/gregstarr/trough) +![GitHub last commit](https://img.shields.io/github/last-commit/gregstarr/trough?color=blue&style=flat) +![Lines of code](https://img.shields.io/tokei/lines/github/gregstarr/trough?color=orange) +![GitHub Repo stars](https://img.shields.io/github/stars/gregstarr/trough?style=social) + +### Example + +![Example](example.png) + +### Features +- Download Madrigal TEC, OMNI and DMSP SSUSI data +- Process datasets into more convenient `xarray` data structures and save as NetCDF +- Automatically label main ionospheric trough + +# Usage + +1. Clone Repo +2. create conda environment using `environment.yml` (if you have trouble with apexpy, install it first) +3. install trough with `pip install -e .` +4. copy `config.json.example` --> `config.json` and change any options you want +5. run with `python -m trough config.json` +6. wait for it to finish (can take several days if you are running 5+ years) +7. 
add `import trough` in your code and access the data using `trough.get_data` + +### Config +#### Main Options +| Config Option | Definition | +| --- |-------------------------------------------------------------------------------------------------------------------------| +| base_dir | base directory of trough downloads and processing, directories for downloading and processing will be created from here | +| madrigal_user_name | name supplied to MadrigalWeb | +| madrigal_user_email | email supplied to MadrigalWeb | +| madrigal_user_affil | affiliation supplied to MadrigalWeb | +| nasa_spdf_download_method | "http" or "ftp" (default) | +| lat_res | latitude resolution of processed TEC maps (degrees Apex magnetic latitude) | +| lon_res | longitude resolution of processed TEC maps (degrees Apex magnetic longitude) | +| time_res_unit | time resolution units (passed to `np.timedelta64`) | +| time_res_n | time resolution in units specified above (passed to `np.timedelta64`) | +| script_name | which script to run, available scripts are in `trough/scripts.py` | +| start_date | start date of interval (YYYYMMDD, YYYYMMDD_hh, YYYYMMDD_hhmm, or YYYYMMDD_hhmmss) | +| end_date | end date of interval, see "start_date" for format | +| keep_download | whether or not to keep the downloaded files (not recommended) | +| trough_id_params | trough labeling algorithm parameters, see below | + +#### Trough Labeling Options +| Config Option | Definition | +| --- |-------------------------------------------------------------------------------------| +| bg_est_shape | background estimation filter size in pixels [time, latitude, longitude] | +| model_weight_max | maximum value of L2 regularization before multiplication by coefficient `l2_weight` | +| rbf_bw | RBF bandwidth, number of pixels to half weight | +| tv_hw | total variation horizontal weight | +| tv_vw | total variation vertical weight | +| l2_weight | L2 regularization coefficient | +| tv_weight | TV regularization coefficient | +| perimeter_th | minimum perimeter for a connected component in a label image | +| area_th | minimum area for a connected component in a label image | +| threshold | score threshold below which a pixel is not labeled as MIT | +| closing_rad | radius for disk structuring element passed to `skimage.morphology.binary_closing` | diff --git a/config.json.example b/config.json.example new file mode 100644 index 0000000..240034a --- /dev/null +++ b/config.json.example @@ -0,0 +1,32 @@ +{ + "base_dir": "path/to/trough_directory", + "trough_id_params": { + "bg_est_shape": [ + 1, + 19, + 17 + ], + "model_weight_max": 15, + "rbf_bw": 1, + "tv_hw": 2, + "tv_vw": 1, + "l2_weight": 0.09, + "tv_weight": 0.15, + "perimeter_th": 30, + "area_th": 30, + "threshold": 1, + "closing_rad": 0 + }, + "madrigal_user_name": "your_name", + "madrigal_user_email": "your_email@email.com", + "madrigal_user_affil": "your_affiliation", + "nasa_spdf_download_method": "ftp", + "lat_res": 1, + "lon_res": 2, + "time_res_unit": "h", + "time_res_n": 1, + "script_name": "full_run", + "start_date": "20100101", + "end_date": "20220101", + "keep_download": false +} \ No newline at end of file diff --git a/environment.yml b/environment.yml index dfa70d2..8b16646 100644 --- a/environment.yml +++ b/environment.yml @@ -2,17 +2,17 @@ channels: - anaconda - conda-forge dependencies: - - numpy - - scipy - - h5py - - scikit-image - - scikit-learn - - appdirs - - bs4 - - pandas - - xarray - - bottleneck - - cvxpy + - numpy==1.21.2 + - scipy==1.7.3 + - h5py==3.6.0 + - 
scikit-image==0.19.1 + - scikit-learn==1.0.2 + - appdirs==1.4.4 + - bs4==4.10.0 + - pandas==1.3.5 + - xarray==0.20.1 + - bottleneck==1.3.2 + - cvxpy==1.1.18 - pip - pip: - - madrigalWeb + - madrigalWeb==3.2 diff --git a/example.png b/example.png new file mode 100644 index 0000000..eab3186 Binary files /dev/null and b/example.png differ diff --git a/test/subtest.py b/test/subtest.py index 8f44255..1221a5f 100644 --- a/test/subtest.py +++ b/test/subtest.py @@ -3,7 +3,7 @@ start_date = datetime(2020, 9, 8, 9) end_date = datetime(2020, 9, 9, 12) -data = trough.get_data(start_date, end_date) +data = trough.get_data(start_date, end_date, 'north') print(data['tec'].shape) print(data['kp'].shape) print(data['labels'].shape) diff --git a/test/test_arb.py b/test/test_arb.py index f8cd6ee..beb7f8b 100644 --- a/test/test_arb.py +++ b/test/test_arb.py @@ -14,13 +14,13 @@ def test_file_list(): - start_date = datetime(2001, 1, 1, 12, 0, 0) - end_date = datetime(2001, 1, 2, 12, 0, 0) + start_date = datetime(2001, 1, 4, 12, 0, 0) + end_date = datetime(2001, 1, 5, 12, 0, 0) with TemporaryDirectory() as tempdir: cache_fn = Path(tempdir) / "file_list.json" cache = {} for sat in ['f16', 'f17', 'f18', 'f19']: - for doy in [1, 2]: + for doy in [3, 4, 5]: cache_key = f"{sat}_{2001}_{doy}" cache[cache_key] = [f'{cache_key}_file_1', f'{cache_key}_file_2'] with open(cache_fn, 'w') as f: @@ -75,14 +75,18 @@ def test_download_arb(test_dates, download_dir): ) def test_process_arb(download_dir, processed_dir, test_dates, dt, mlt_vals): start, end = test_dates - correct_times = np.arange(np.datetime64(start, 's'), np.datetime64(end, 's'), dt) + correct_times = np.arange(np.datetime64(start, 's'), np.datetime64(end, 's') + dt, dt) processed_file = Path(processed_dir) / 'arb_test.nc' - process_interval(start, end, processed_file, download_dir, mlt_vals, dt) - assert processed_file.exists() - data = xr.open_dataarray(processed_file) - assert data.shape == (correct_times.shape[0], mlt_vals.shape[0]) - assert (data.mlt == mlt_vals).all().item() - assert (data.time == correct_times).all().item() + for hemisphere in ['north', 'south']: + process_interval(start, end, hemisphere, processed_file, download_dir, mlt_vals, dt) + assert processed_file.exists() + data = xr.open_dataarray(processed_file) + data.load() + assert data.shape == (correct_times.shape[0], mlt_vals.shape[0]) + assert (data.mlt == mlt_vals).all().item() + assert (data.time == correct_times).all().item() + data.close() + processed_file.unlink() def test_process_arb_out_of_range(download_dir, processed_dir, test_dates): @@ -90,37 +94,38 @@ def test_process_arb_out_of_range(download_dir, processed_dir, test_dates): start, end = [date - timedelta(days=100) for date in test_dates] processed_file = Path(processed_dir) / 'arb_test.nc' with pytest.raises(InvalidProcessDates): - process_interval(start, end, processed_file, download_dir, config.get_mlt_vals(), dt) + process_interval(start, end, 'north', processed_file, download_dir, config.get_mlt_vals(), dt) def test_get_arb_data(download_dir, processed_dir, test_dates): start, end = test_dates dt = np.timedelta64(1, 'h') mlt = config.get_mlt_vals() - correct_times = np.arange(np.datetime64(start), np.datetime64(end), dt) - processed_file = get_arb_paths(start, end, processed_dir)[0] - process_interval(start, end, processed_file, download_dir, mlt, dt) - data = get_arb_data(start, end, processed_dir) + correct_times = np.arange(np.datetime64(start), np.datetime64(end) + dt, dt) + processed_file = get_arb_paths(start, 
end, 'north', processed_dir)[0] + process_interval(start, end, 'north', processed_file, download_dir, mlt, dt) + data = get_arb_data(start, end, 'north', processed_dir) assert data.shape == (correct_times.shape[0], mlt.shape[0]) assert (data.mlt == mlt).all().item() assert (data.time == correct_times).all().item() def test_scripts(test_dates): + start, end = test_dates with TemporaryDirectory() as base_dir: with config.temp_config(base_dir=base_dir) as cfg: - scripts.download_arb(*test_dates) + scripts.download_arb(start, end) arb_files = list(Path(cfg.download_arb_dir).glob('*')) assert len(arb_files) > 0 - data, times = _get_downloaded_arb_data(*test_dates, cfg.download_arb_dir) - assert min(times) < test_dates[0] - assert max(times) > test_dates[-1] - scripts.process_arb(*test_dates) - data = get_arb_data(*test_dates, cfg.processed_arb_dir) + data, times = _get_downloaded_arb_data(start, end, cfg.download_arb_dir) + assert min(times) < start + assert max(times) > end + scripts.process_arb(start, end) + data = get_arb_data(start, end, 'north', cfg.processed_arb_dir) data.load() dt = np.timedelta64(1, 'h') mlt = config.get_mlt_vals() - correct_times = np.arange(np.datetime64(test_dates[0]), np.datetime64(test_dates[-1]), dt) + correct_times = np.arange(np.datetime64(test_dates[0]), np.datetime64(test_dates[-1]) + dt, dt) assert data.shape == (correct_times.shape[0], mlt.shape[0]) assert (data.mlt == mlt).all().item() assert (data.time == correct_times).all().item() diff --git a/test/test_tec.py b/test/test_tec.py index 77a7268..af89053 100644 --- a/test/test_tec.py +++ b/test/test_tec.py @@ -33,7 +33,35 @@ def test_calculate_bins(): dims=['time', 'x', 'y'] ) be = np.array([-.5, 4.5, 9.5]) - out_tec = calculate_bins(data, be, be) + out_tec = calculate_bins(data, be, be, 'north') + assert out_tec.shape == (2, 2) + assert out_tec[0, 0] == 10 / 25 + assert out_tec[0, 1] == 20 / 25 + assert out_tec[1, 0] == 30 / 25 + assert out_tec[1, 1] == 0 + + +def test_calculate_bins_south(): + mlat = np.arange(10)[:, None] * np.ones((1, 10)) * -1 + mlt = np.arange(10)[None, None, :] * np.ones((1, 10, 1)) + tec = np.zeros((1, 10, 10)) + tec[0, 0, 0] = 10 + tec[0, 0, -1] = 20 + tec[0, -1, 0] = 30 + times = np.ones(1) * np.nan + data = xr.DataArray( + tec, + coords={ + 'time': times, + 'x': np.arange(10), + 'y': np.arange(10), + 'mlat': (('x', 'y'), mlat), + 'mlt': (('time', 'x', 'y'), mlt) + }, + dims=['time', 'x', 'y'] + ) + be = np.array([-.5, 4.5, 9.5]) + out_tec = calculate_bins(data, be, be, 'south') assert out_tec.shape == (2, 2) assert out_tec[0, 0] == 10 / 25 assert out_tec[0, 1] == 20 / 25 @@ -46,7 +74,7 @@ def test_file_list(): end_date = datetime(2001, 1, 2, 12, 0, 0) with TemporaryDirectory() as tempdir: cache_fn = Path(tempdir) / "file_list.json" - cache = {'100139613': 'file_1', '100139351': 'file_2'} + cache = {'100139613': 'file_1', '100139351': 'file_2', '100119343': 'file_3', '100139428': 'file_4'} with open(cache_fn, 'w') as f: json.dump(cache, f) downloader = MadrigalTecDownloader(tempdir, 'gstarr', 'gstarr@bu.edu', 'bu') @@ -100,17 +128,24 @@ def test_download_tec(test_dates, download_dir): ) def test_process_tec(download_dir, process_dir, test_dates, dt, mlt_bins, mlat_bins): start, end = test_dates - correct_times = np.arange(np.datetime64(start, 's'), np.datetime64(end, 's'), dt) + correct_times = np.arange(np.datetime64(start, 's'), np.datetime64(end, 's') + dt, dt) mlt_vals = (mlt_bins[:-1] + mlt_bins[1:]) / 2 mlat_vals = (mlat_bins[:-1] + mlat_bins[1:]) / 2 processed_file = 
Path(process_dir) / 'tec_test.nc' - process_interval(start, end, processed_file, download_dir, dt, mlat_bins, mlt_bins) - assert processed_file.exists() - data = xr.open_dataarray(processed_file) - assert data.shape == (correct_times.shape[0], mlat_vals.shape[0], mlt_vals.shape[0]) - assert (data.mlt == mlt_vals).all().item() - assert (data.mlat == mlat_vals).all().item() - assert (data.time == correct_times).all().item() + for hemisphere in ['north', 'south']: + process_interval(start, end, hemisphere, processed_file, download_dir, dt, mlat_bins, mlt_bins) + assert processed_file.exists() + data = xr.open_dataarray(processed_file) + data.load() + h = 1 if hemisphere == 'north' else -1 + try: + assert data.shape == (correct_times.shape[0], mlat_vals.shape[0], mlt_vals.shape[0]) + assert (data.mlt == mlt_vals).all().item() + assert (data.mlat == h * mlat_vals).all().item() + assert (data.time == correct_times).all().item() + finally: + data.close() + processed_file.unlink() def test_process_tec_out_of_range(download_dir, process_dir, test_dates): @@ -118,43 +153,49 @@ def test_process_tec_out_of_range(download_dir, process_dir, test_dates): start, end = [date - timedelta(days=100) for date in test_dates] processed_file = Path(process_dir) / 'tec_test.nc' with pytest.raises(InvalidProcessDates): - process_interval(start, end, processed_file, download_dir, dt, config.get_mlat_bins(), config.get_mlt_bins()) + process_interval(start, end, 'north', processed_file, download_dir, dt, config.get_mlat_bins(), config.get_mlt_bins()) def test_get_tec_data(download_dir, process_dir, test_dates): start, end = test_dates dt = np.timedelta64(1, 'h') - mlt_bins = config.get_mlt_vals() - mlat_bins = config.get_mlat_vals() + mlt_bins = config.get_mlt_bins() + mlat_bins = config.get_mlat_bins() mlt_vals = (mlt_bins[:-1] + mlt_bins[1:]) / 2 mlat_vals = (mlat_bins[:-1] + mlat_bins[1:]) / 2 - correct_times = np.arange(np.datetime64(start), np.datetime64(end), dt) - processed_file = get_tec_paths(start, end, process_dir)[0] - process_interval(start, end, processed_file, download_dir, dt, mlat_bins, mlt_bins) - data = get_tec_data(start, end, process_dir) - assert data.shape == (correct_times.shape[0], mlat_vals.shape[0], mlt_vals.shape[0]) - assert (data.mlt == mlt_vals).all().item() - assert (data.mlat == mlat_vals).all().item() - assert (data.time == correct_times).all().item() + correct_times = np.arange(np.datetime64(start), np.datetime64(end) + dt, dt) + for hemisphere in ['north', 'south']: + processed_file = get_tec_paths(start, end, hemisphere, process_dir)[0] + process_interval(start, end, hemisphere, processed_file, download_dir, dt, mlat_bins, mlt_bins) + data = get_tec_data(start, end, hemisphere, process_dir) + h = 1 if hemisphere == 'north' else -1 + assert data.shape == (correct_times.shape[0], mlat_vals.shape[0], mlt_vals.shape[0]) + assert (data.mlt == mlt_vals).all().item() + assert (data.mlat == h * mlat_vals).all().item() + assert (data.time == correct_times).all().item() def test_scripts(test_dates): + start, end = test_dates with TemporaryDirectory() as base_dir: with config.temp_config(base_dir=base_dir) as cfg: - scripts.download_tec(*test_dates) + scripts.download_tec(start, end) tec_files = list(Path(cfg.download_tec_dir).glob('*')) assert len(tec_files) > 0 - data = _get_downloaded_tec_data(*test_dates, cfg.download_tec_dir) + data = _get_downloaded_tec_data(start, end, cfg.download_tec_dir) assert data.time.values[0] < np.datetime64(test_dates[0], 's') assert data.time.values[-1] > 
np.datetime64(test_dates[-1], 's') - scripts.process_tec(*test_dates) - data = get_tec_data(*test_dates, cfg.processed_tec_dir) - data.load() - dt = np.timedelta64(1, 'h') - mlt_vals = config.get_mlt_vals() - mlat_vals = config.get_mlat_vals() - correct_times = np.arange(np.datetime64(test_dates[0]), np.datetime64(test_dates[-1]), dt) - assert data.shape == (correct_times.shape[0], mlat_vals.shape[0], mlt_vals.shape[0]) - assert (data.mlt == mlt_vals).all().item() - assert (data.mlat == mlat_vals).all().item() - assert (data.time == correct_times).all().item() + scripts.process_tec(start, end) + + for hemisphere in ['north', 'south']: + data = get_tec_data(start, end, hemisphere, cfg.processed_tec_dir) + data.load() + dt = np.timedelta64(1, 'h') + mlt_vals = config.get_mlt_vals() + mlat_vals = config.get_mlat_vals() + correct_times = np.arange(np.datetime64(test_dates[0]), np.datetime64(test_dates[-1]) + dt, dt) + h = 1 if hemisphere == 'north' else -1 + assert data.shape == (correct_times.shape[0], mlat_vals.shape[0], mlt_vals.shape[0]) + assert (data.mlt == mlt_vals).all().item() + assert (data.mlat == h * mlat_vals).all().item() + assert (data.time == correct_times).all().item() diff --git a/test/test_trough.py b/test/test_trough.py index efe266e..2c4e429 100644 --- a/test/test_trough.py +++ b/test/test_trough.py @@ -3,8 +3,9 @@ from datetime import datetime, timedelta from tempfile import TemporaryDirectory from pathlib import Path +import pytest -from trough import config, _trough, scripts +from trough import config, _trough, scripts, get_data from trough._config import TroughIdParams @@ -95,7 +96,7 @@ def test_postprocess(): dims=['time', 'mlt'] ), }) - _trough.postprocess(data, perimeter_th=50) + _trough.postprocess(data, 'north', perimeter_th=50) labels = data['labels'].values[0] assert labels[good_labels].all() assert not labels[small_reject].any() @@ -150,43 +151,69 @@ def test_get_tec_troughs(): """Verify that get_tec_troughs can detect an actual trough, verify that high troughs are rejected using auroral boundary data """ + n_hours = 12 start_date = datetime(2015, 10, 7, 6, 0, 0) - end_date = start_date + timedelta(hours=12) + end_date = start_date + timedelta(hours=n_hours) params = TroughIdParams(bg_est_shape=(1, 19, 19), model_weight_max=5, l2_weight=.1, tv_weight=.05, tv_hw=2) with config.temp_config(trough_id_params=params): scripts.download_all(start_date, end_date) scripts.process_all(start_date, end_date) - data = _trough.label_trough_interval(start_date, end_date, config.trough_id_params, config.processed_tec_dir, - config.processed_arb_dir, config.processed_omni_file) - - labels = data['labels'].values - assert labels.shape == (12, 60, 180) - assert labels[1, 20:30, 60:120].mean() > .5 + data_north = _trough.label_trough_interval( + start_date, end_date, config.trough_id_params, 'north', + config.processed_tec_dir, config.processed_arb_dir, config.processed_omni_file + ) + data_south = _trough.label_trough_interval( + start_date, end_date, config.trough_id_params, 'south', + config.processed_tec_dir, config.processed_arb_dir, config.processed_omni_file + ) + + labels_north = data_north['labels'].values + labels_south = data_south['labels'].values + assert labels_north.shape == (n_hours + 1, 60, 180) + assert labels_north[1, 20:30, 60:120].mean() > .5 + assert labels_south.shape == (n_hours + 1, 60, 180) + assert labels_south[3, 20:30, 60:80].mean() > .5 for i in range(12): - assert labels[i][(data.mlat > data['arb'][i] + 3).values].sum() == 0 - - -def 
test_process_trough_interval(test_dates): - start_date, end_date = test_dates + assert labels_north[i][(data_north.mlat > data_north['arb'][i] + 3).values].sum() == 0 + assert labels_south[i][(data_south.mlat < data_south['arb'][i] - 3).values].sum() == 0 + + +@pytest.mark.parametrize('dates', + [ + [datetime(2021, 1, 3, 6, 0, 0), datetime(2021, 1, 3, 12, 0, 0)], + [datetime(2020, 12, 31, 20, 0, 0), datetime(2021, 1, 1, 4, 0, 0)] + ]) +def test_process_trough_interval(dates): + start_date, end_date = dates + n_times = 1 + ((end_date - start_date) / timedelta(hours=1)) scripts.download_all(start_date, end_date) scripts.process_all(start_date, end_date) - data = _trough.label_trough_interval(start_date, end_date, config.trough_id_params, config.processed_tec_dir, - config.processed_arb_dir, config.processed_omni_file) + data = _trough.label_trough_interval( + start_date, end_date, config.trough_id_params, 'north', + config.processed_tec_dir, config.processed_arb_dir, config.processed_omni_file + ) assert 'labels' in data assert 'tec' in data - assert data.time.shape[0] == 6 + assert data.time.shape[0] == n_times assert data.mlat.shape[0] == 60 assert data.mlt.shape[0] == 180 assert np.nanmean(data['tec'].values[data['labels'].values]) < np.nanmean(data['tec'].values[~data['labels'].values]) -def test_script(test_dates): +@pytest.mark.parametrize('dates', + [ + [datetime(2021, 1, 3, 6, 0, 0), datetime(2021, 1, 3, 12, 0, 0)], + [datetime(2020, 12, 31, 20, 0, 0), datetime(2021, 1, 1, 4, 0, 0)] + ]) +def test_script(dates): + start_date, end_date = dates + n_times = 1 + ((end_date - start_date) / timedelta(hours=1)) with TemporaryDirectory() as tempdir: with config.temp_config(base_dir=tempdir): - scripts.full_run(*test_dates) - path = Path(config.processed_labels_dir) / "labels_2021.nc" - assert path.exists() - data = xr.open_dataarray(path) + scripts.full_run(*dates) + n_files = len([p for p in Path(config.processed_labels_dir).glob('labels*.nc')]) + assert n_files == (end_date.year - start_date.year + 1) * 2 + data = get_data(start_date, end_date, 'north') data.load() - assert data.size > 1 + assert data.time.shape[0] == n_times data.close() diff --git a/trough/_arb.py b/trough/_arb.py index a1bfe70..7553bb5 100644 --- a/trough/_arb.py +++ b/trough/_arb.py @@ -6,33 +6,37 @@ from scipy.interpolate import interp1d import logging import warnings + try: import h5py -except ImportError as e: - warnings.warn(f"Packages required for recreating dataset not installed: {e}") +except ImportError as imp_err: + warnings.warn(f"Packages required for recreating dataset not installed: {imp_err}") from trough import config, utils from trough.exceptions import InvalidProcessDates - -_arb_fields = ['YEAR', 'DOY', 'TIME', 'ALTITUDE', 'MODEL_NORTH_GEOGRAPHIC_LATITUDE', 'MODEL_NORTH_GEOGRAPHIC_LONGITUDE'] +_arb_fields = [ + 'YEAR', 'DOY', 'TIME', 'ALTITUDE', + 'MODEL_NORTH_GEOGRAPHIC_LATITUDE', 'MODEL_NORTH_GEOGRAPHIC_LONGITUDE', + 'MODEL_SOUTH_GEOGRAPHIC_LATITUDE', 'MODEL_SOUTH_GEOGRAPHIC_LONGITUDE' +] logger = logging.getLogger(__name__) -def get_arb_paths(start_date, end_date, processed_dir): +def get_arb_paths(start_date, end_date, hemisphere, processed_dir): file_dates = np.arange( np.datetime64(start_date, 'Y'), - (np.datetime64(end_date, 's') - np.timedelta64(1, 'h')).astype('datetime64[Y]') + 1, + (np.datetime64(end_date, 's')).astype('datetime64[Y]') + 1, np.timedelta64(1, 'Y') ) file_dates = utils.decompose_datetime64(file_dates) - return [Path(processed_dir) / f"arb_{d[0]:04d}.nc" for d in file_dates] + 
return [Path(processed_dir) / f"arb_{hemisphere}_{d[0]:04d}.nc" for d in file_dates] -def get_arb_data(start_date, end_date, processed_dir=None): +def get_arb_data(start_date, end_date, hemisphere, processed_dir=None): if processed_dir is None: processed_dir = config.processed_arb_dir - data = xr.concat([xr.open_dataarray(file) for file in get_arb_paths(start_date, end_date, processed_dir)], 'time') + data = xr.concat([xr.open_dataarray(file) for file in get_arb_paths(start_date, end_date, hemisphere, processed_dir)], 'time') return data.sel(time=slice(start_date, end_date)) @@ -43,8 +47,8 @@ def parse_arb_fn(path): def _get_downloaded_arb_data(start_date, end_date, input_dir): - start_date -= timedelta(hours=3) - end_date += timedelta(hours=3) + start_date -= timedelta(days=1) + end_date += timedelta(days=1) data = {field: [] for field in _arb_fields} data['sat'] = [] for path in Path(input_dir).glob('*.NC'): @@ -64,45 +68,43 @@ def _get_downloaded_arb_data(start_date, end_date, input_dir): return data, times -def process_interval(start_date, end_date, output_fn, input_dir, mlt_vals, sample_dt): +def process_interval(start_date, end_date, hemisphere, output_fn, input_dir, mlt_vals, sample_dt): logger.info(f"processing arb data for {start_date, end_date}") - ref_times = np.arange(np.datetime64(start_date, 's'), np.datetime64(end_date, 's'), sample_dt) + ref_times = np.arange(np.datetime64(start_date, 's'), np.datetime64(end_date, 's') + sample_dt, sample_dt) apex = Apex(date=start_date) - data, times = _get_downloaded_arb_data(start_date, end_date, input_dir) + arb_data, times = _get_downloaded_arb_data(start_date, end_date, input_dir) if times.size == 0 or min(times) > ref_times[0] or max(times) < ref_times[-1]: logger.error(f"times size: {times.size}") if len(times) > 0: - logger.error(f"times: {min(times)} - {max(times)}") + logger.error(f"{min(times)=} {ref_times[0]=} {max(times)=} {ref_times[-1]=}") raise InvalidProcessDates("Need to download full data range before processing") logger.info(f"{times.shape[0]} time points") sort_idx = np.argsort(times) - times = times[sort_idx] + mlat = np.empty((times.shape[0], mlt_vals.shape[0])) for i, idx in enumerate(sort_idx): - lat = data['MODEL_NORTH_GEOGRAPHIC_LATITUDE'][idx] - lon = data['MODEL_NORTH_GEOGRAPHIC_LONGITUDE'][idx] - height = np.mean(data['ALTITUDE'][idx]) + height = np.mean(arb_data['ALTITUDE'][idx]) + lat = arb_data[f'MODEL_{hemisphere.upper()}_GEOGRAPHIC_LATITUDE'][idx] + lon = arb_data[f'MODEL_{hemisphere.upper()}_GEOGRAPHIC_LONGITUDE'][idx] apx_lat, mlt = apex.convert(lat, lon, 'geo', 'mlt', height, utils.datetime64_to_datetime(times[idx])) mlat[i] = np.interp(mlt_vals, mlt, apx_lat, period=24) - logger.info(f"ref times: [{ref_times[0]}, {ref_times[-1]}]") - interpolator = interp1d(times.astype('datetime64[s]').astype(float), mlat, axis=0, bounds_error=False) + good_mask = np.mean(abs(mlat - np.median(mlat, axis=0, keepdims=True)), axis=1) < 1 + interpolator = interp1d( + times.astype('datetime64[s]').astype(float)[sort_idx][good_mask], + mlat[good_mask], + axis=0, bounds_error=False + ) mlat = interpolator(ref_times.astype(float)) - data = xr.DataArray(mlat, coords={'time': ref_times, 'mlt': mlt_vals}, dims=['time', 'mlt']) + data = xr.DataArray( + mlat, + coords={'time': ref_times, 'mlt': mlt_vals}, + dims=['time', 'mlt'] + ) + logger.info(f"ref times: [{ref_times[0]}, {ref_times[-1]}]") data.to_netcdf(output_fn) -def check_processed_data_interval(start, end, processed_file): - if processed_file.exists(): - 
logger.info(f"processed file already exists {processed_file=}, checking...") - try: - data_check = get_arb_data(start, end, processed_file.parent) - if not data_check.isnull().all(dim=['mlt']).any().item(): - logger.info(f"downloaded data already processed {processed_file=}, checking...") - return False - except Exception as e: - logger.info(f"error reading processed file {processed_file=}: {e}, removing and reprocessing") - processed_file.unlink() - return True +check_processed_data_interval = utils.get_data_checker(get_arb_data) def process_auroral_boundary_dataset(start_date, end_date, download_dir=None, process_dir=None, mlt_vals=None, dt=None): @@ -117,10 +119,12 @@ def process_auroral_boundary_dataset(start_date, end_date, download_dir=None, pr Path(process_dir).mkdir(exist_ok=True, parents=True) for year in range(start_date.year, end_date.year + 1): - output_file = Path(process_dir) / f"arb_{year:04d}.nc" start = max(start_date, datetime(year, 1, 1)) - end = min(end_date, datetime(year + 1, 1, 1)) + end = utils.datetime64_to_datetime(np.datetime64(datetime(year + 1, 1, 1)) - dt) + end = min(end_date, end) if end - start <= timedelta(hours=1): continue - if check_processed_data_interval(start, end, output_file): - process_interval(start, end, output_file, download_dir, mlt_vals, dt) + for hemisphere in ['north', 'south']: + output_file = Path(process_dir) / f"arb_{hemisphere}_{year:04d}.nc" + if check_processed_data_interval(start, end, dt, hemisphere, output_file): + process_interval(start, end, hemisphere, output_file, download_dir, mlt_vals, dt) diff --git a/trough/_config.py b/trough/_config.py index f2f13fa..57d76be 100644 --- a/trough/_config.py +++ b/trough/_config.py @@ -84,6 +84,7 @@ def __init__(self, config_path=None): self.lon_res = 2 self.time_res_unit = 'h' self.time_res_n = 1 + self.mlat_min = 30 self.script_name = 'full_run' self.start_date = None @@ -99,10 +100,10 @@ def get_config_name(self): return f"{cfg['script_name']}_{cfg['start_date']}_{cfg['end_date']}_config.json" def get_mlat_bins(self): - return np.arange(29.5, 90, self.lat_res) + return np.arange(self.mlat_min - self.lat_res / 2, 90, self.lat_res) def get_mlat_vals(self): - return np.arange(29.5 + self.lat_res / 2, 90, self.lat_res) + return np.arange(self.mlat_min, 90, self.lat_res) def get_mlt_bins(self): return np.arange(-12, 12 + 24 / 360, self.lon_res * 24 / 360) diff --git a/trough/_download.py b/trough/_download.py index a425223..f0bcdaf 100644 --- a/trough/_download.py +++ b/trough/_download.py @@ -15,8 +15,8 @@ import h5py from madrigalWeb import madrigalWeb import bs4 -except ImportError as e: - warnings.warn(f"Packages required for recreating dataset not installed: {e}") +except ImportError as imp_err: + warnings.warn(f"Packages required for recreating dataset not installed: {imp_err}") from trough.exceptions import InvalidConfiguration @@ -166,7 +166,7 @@ def _download_files(self, files): def _get_file_list(self, start_date, end_date): logger.info("Getting file list...") - experiments = sorted(self._get_tec_experiments(start_date - timedelta(hours=3), end_date + timedelta(hours=3))) + experiments = sorted(self._get_tec_experiments(start_date - timedelta(days=1), end_date + timedelta(days=1))) logger.info(f"found {len(experiments)} experiments") tec_files = {} for i, experiment in enumerate(experiments): @@ -232,8 +232,8 @@ def _download_ftp_file(self, file, local_path): _download_ftp_file(self.server, file, local_path) def _get_file_list(self, start_date, end_date): - new_start_date = 
start_date - timedelta(hours=3) - new_end_date = end_date + timedelta(hours=3) + new_start_date = start_date - timedelta(days=1) + new_end_date = end_date + timedelta(days=1) files = { str(year): [f'/pub/data/omni/low_res_omni/omni2_{year:4d}.dat'] for year in range(new_start_date.year, new_end_date.year + 1) @@ -274,11 +274,11 @@ def _download_files(self, files): return local_files def _get_file_list(self, start_date, end_date): - start_date -= timedelta(hours=3) - end_date += timedelta(hours=3) - n_days = math.ceil((end_date - start_date) / timedelta(days=1)) + date1 = start_date - timedelta(days=1) + date2 = end_date + timedelta(days=1) + n_days = math.ceil((date2 - date1) / timedelta(days=1)) logger.info(f"getting files for {n_days} days") - days = [start_date + timedelta(days=t) for t in range(n_days)] + days = [date1 + timedelta(days=t) for t in range(n_days)] arb_files = {} for i, day in enumerate(days): if len(days) > 100 and not (i % (len(days) // 100)): diff --git a/trough/_omni.py b/trough/_omni.py index a3fff05..18eb142 100644 --- a/trough/_omni.py +++ b/trough/_omni.py @@ -32,5 +32,5 @@ def open_downloaded_omni_file(fn): def process_omni_dataset(input_dir, output_fn): output_path = Path(output_fn) output_path.parent.mkdir(exist_ok=True, parents=True) - data = xr.concat([open_downloaded_omni_file(path) for path in Path(input_dir).glob('*.dat')], 'time') + data = xr.concat([open_downloaded_omni_file(path) for path in Path(input_dir).glob('*.dat')], 'time').sortby('time') data.to_netcdf(output_fn) diff --git a/trough/_tec.py b/trough/_tec.py index 208aeeb..b04134c 100644 --- a/trough/_tec.py +++ b/trough/_tec.py @@ -9,8 +9,8 @@ import warnings try: import h5py -except ImportError as e: - warnings.warn(f"Packages required for recreating dataset not installed: {e}") +except ImportError as imp_err: + warnings.warn(f"Packages required for recreating dataset not installed: {imp_err}") from trough import config, utils from trough.exceptions import InvalidProcessDates @@ -60,32 +60,38 @@ def open_madrigal_file(fn): ) -def get_tec_paths(start_date, end_date, processed_dir): +def get_tec_paths(start_date, end_date, hemisphere, processed_dir): file_dates = np.arange( np.datetime64(start_date, 'M'), - (np.datetime64(end_date, 's') - np.timedelta64(1, 'h')).astype('datetime64[M]') + 1, + (np.datetime64(end_date, 's')).astype('datetime64[M]') + 1, np.timedelta64(1, 'M') ) file_dates = utils.decompose_datetime64(file_dates) - return [Path(processed_dir) / f"tec_{d[0]:04d}_{d[1]:02d}.nc" for d in file_dates] + return [Path(processed_dir) / f"tec_{hemisphere}_{d[0]:04d}_{d[1]:02d}.nc" for d in file_dates] -def get_tec_data(start_date, end_date, processed_dir=None): +def get_tec_data(start_date, end_date, hemisphere, processed_dir=None): if processed_dir is None: processed_dir = config.processed_tec_dir - data = xr.concat([xr.open_dataarray(file) for file in get_tec_paths(start_date, end_date, processed_dir)], 'time') + data = xr.concat([xr.open_dataarray(file) for file in get_tec_paths(start_date, end_date, hemisphere, processed_dir)], 'time') return data.sel(time=slice(start_date, end_date)) -def calculate_bins(data, mlat_bins, mlt_bins): +def calculate_bins(data, mlat_bins, mlt_bins, hemisphere): """Calculates TEC in MLAT - MLT bins. Executed in process pool. 
""" if data.shape == (): tec = np.ones((mlat_bins.shape[0] - 1, mlt_bins.shape[0] - 1)) * np.nan else: - mask = np.isfinite(data.values) + if hemisphere == 'north': + mlat_grid = np.broadcast_to(data.mlat, data.shape) + elif hemisphere == 'south': + mlat_grid = np.broadcast_to(data.mlat, data.shape) * -1 + else: + raise ValueError(f"Invalid hemisphere: {hemisphere}, valid = ['north', 'south']") + mask = np.isfinite(data.values) & (mlat_grid >= 0) tec = binned_statistic_2d( - np.broadcast_to(data.mlat, data.shape)[mask], + mlat_grid[mask], data.mlt.values[mask], data.values[mask], statistic='mean', @@ -119,12 +125,12 @@ def get_mag_coords(apex, mad_data): ) -def process_interval(start_date, end_date, output_fn, input_dir, sample_dt, mlat_bins, mlt_bins): +def process_interval(start_date, end_date, hemisphere, output_fn, input_dir, sample_dt, mlat_bins, mlt_bins): """Processes an interval of madrigal data and writes to files. """ logger.info(f"processing tec data for {start_date, end_date}") - calc_bins = functools.partial(calculate_bins, mlat_bins=mlat_bins, mlt_bins=mlt_bins) - ref_times = np.arange(np.datetime64(start_date, 's'), np.datetime64(end_date, 's'), sample_dt) + calc_bins = functools.partial(calculate_bins, mlat_bins=mlat_bins, mlt_bins=mlt_bins, hemisphere=hemisphere) + ref_times = np.arange(np.datetime64(start_date, 's'), np.datetime64(end_date, 's') + sample_dt, sample_dt) logger.info(f"ref times: {ref_times.shape}, {ref_times[0]=}, {ref_times[-1]=}") mad_data = _get_downloaded_tec_data(start_date, end_date, input_dir) if mad_data.shape == () or min(mad_data.time.values) > ref_times[0] or max(mad_data.time.values) < ref_times[-1]: @@ -140,15 +146,16 @@ def process_interval(start_date, end_date, output_fn, input_dir, sample_dt, mlat mad_data = get_mag_coords(apex, mad_data) logger.info(f"Setting up for binning, {mad_data.time.values[0]=}, {mad_data.time.values[-1]=}") - time_bins = np.arange(mad_data.time.values[0], mad_data.time.values[-1] + sample_dt, sample_dt) + time_bins = np.arange(ref_times[0], ref_times[-1] + 2 * sample_dt, sample_dt) data_groups = mad_data.groupby_bins('time', bins=time_bins, right=False) data = [_data for _interval, _data in data_groups] logger.info("Calculated bins") + h = 1 if hemisphere == 'north' else -1 tec = xr.DataArray( np.array([result for result in map(calc_bins, data)]), coords={ 'time': np.array([_interval.left for _interval, _data in data_groups]), - 'mlat': (mlat_bins[:-1] + mlat_bins[1:]) / 2, + 'mlat': h * (mlat_bins[:-1] + mlat_bins[1:]) / 2, 'mlt': (mlt_bins[:-1] + mlt_bins[1:]) / 2, }, dims=['time', 'mlat', 'mlt'] @@ -156,18 +163,7 @@ def process_interval(start_date, end_date, output_fn, input_dir, sample_dt, mlat tec.to_netcdf(output_fn) -def check_processed_data_interval(start, end, processed_file): - if processed_file.exists(): - logger.info(f"processed file already exists {processed_file=}, checking...") - try: - data_check = get_tec_data(start, end, processed_file.parent) - if not data_check.isnull().all(dim=['mlt', 'mlat']).any().item(): - logger.info(f"downloaded data already processed {processed_file=}, checking...") - return False - except Exception as e: - logger.info(f"error reading processed file {processed_file=}: {e}, removing and reprocessing") - processed_file.unlink() - return True +check_processed_data_interval = utils.get_data_checker(get_tec_data) def process_tec_dataset(start_date, end_date, download_dir=None, process_dir=None, dt=None, mlat_bins=None, @@ -184,14 +180,22 @@ def process_tec_dataset(start_date, 
end_date, download_dir=None, process_dir=Non dt = np.timedelta64(1, 'h') Path(process_dir).mkdir(exist_ok=True, parents=True) + logger.info(f"processing tec dataset over interval {start_date=} {end_date=}") for year in range(start_date.year, end_date.year + 1): + logger.info(f"tec year {year=}") for month in range(1, 13): - output_file = Path(process_dir) / f"tec_{year:04d}_{month:02d}.nc" start = datetime(year, month, 1) - end = datetime(year, month + 1, 1) if month < 12 else datetime(year + 1, 1, 1) + if month == 12: + end = utils.datetime64_to_datetime(np.datetime64(datetime(year + 1, 1, 1)) - dt) + else: + end = datetime(year, month + 1, 1) + logger.info(f"tec interval {start=} {end=}") if start >= end_date or end <= start_date: continue - start = max(start_date, start) - end = min(end_date, end) - if check_processed_data_interval(start, end, output_file): - process_interval(start, end, output_file, download_dir, dt, mlat_bins, mlt_bins) + for hemisphere in ['north', 'south']: + output_file = Path(process_dir) / f"tec_{hemisphere}_{year:04d}_{month:02d}.nc" + start = max(start_date, start) + end = min(end_date, end) + logger.info(f"reduced tec interval {start=} {end=} {hemisphere=}") + if check_processed_data_interval(start, end, dt, hemisphere, output_file): + process_interval(start, end, hemisphere, output_file, download_dir, dt, mlat_bins, mlt_bins) diff --git a/trough/_trough.py b/trough/_trough.py index 14e9b03..f9530d5 100644 --- a/trough/_trough.py +++ b/trough/_trough.py @@ -13,8 +13,8 @@ import pandas from skimage import measure, morphology from sklearn.metrics.pairwise import rbf_kernel -except ImportError as e: - warnings.warn(f"Packages required for recreating dataset not installed: {e}") +except ImportError as imp_err: + warnings.warn(f"Packages required for recreating dataset not installed: {imp_err}") from trough import config, utils, _tec, _arb @@ -22,17 +22,22 @@ logger = logging.getLogger(__name__) -def get_model(tec_data, omni_file): +def get_model(tec_data, hemisphere, omni_file): """Get magnetic latitudes of the trough according to the model in Deminov 2017 for a specific time and set of magnetic local times. 
""" + logger.info("getting model") omni_data = xr.open_dataset(omni_file) + logger.info(f"{omni_data.time.values[0]=} {omni_data.time.values[-1]=}") kp = _get_weighted_kp(tec_data.time, omni_data) + logger.info(f"{kp.shape=}") apex = Apex(date=utils.datetime64_to_datetime(tec_data.time.values[0])) mlat = 65.5 * np.ones((tec_data.time.shape[0], tec_data.mlt.shape[0])) + if hemisphere == 'south': + mlat = mlat * -1 for i in range(10): glat, glon = apex.convert(mlat, tec_data.mlt.values[None, :], 'mlt', 'geo', 350, tec_data.time.values[:, None]) - mlat = _model_subroutine_lat(tec_data.mlt.values[None, :], glon, kp[:, None]) + mlat = _model_subroutine_lat(tec_data.mlt.values[None, :], glon, kp[:, None], hemisphere) tec_data['model'] = xr.DataArray( mlat, coords={'time': tec_data.time, 'mlt': tec_data.mlt}, @@ -40,7 +45,7 @@ def get_model(tec_data, omni_file): ) -def _model_subroutine_lat(mlt, glon, kp): +def _model_subroutine_lat(mlt, glon, kp, hemisphere): """Get's model output mlat given MLT, geographic lon and weighted kp Parameters @@ -54,13 +59,19 @@ def _model_subroutine_lat(mlt, glon, kp): mlat: numpy.ndarray (n_t, n_mlt) """ phi_t = 3.16 - 5.6 * np.cos(np.deg2rad(15 * (mlt - 2.4))) + 1.4 * np.cos(np.deg2rad(15 * (2 * mlt - .8))) - phi_lon = .85 * np.cos(np.deg2rad(glon + 63)) - .52 * np.cos(np.deg2rad(2 * glon + 5)) + if hemisphere == 'north': + phi_lon = .85 * np.cos(np.deg2rad(glon + 63)) - .52 * np.cos(np.deg2rad(2 * glon + 5)) + elif hemisphere == 'south': + phi_lon = 1.5 * np.cos(np.deg2rad(glon - 119)) + else: + raise ValueError(f"Invalid hemisphere: {hemisphere}, valid = ['north', 'south']") return 65.5 - 2.4 * kp + phi_t + phi_lon * np.exp(-.3 * kp) def _get_weighted_kp(times, omni_data, tau=.6, T=10): """Get a weighed sum of kp values over time. See paper for details. 
""" + logger.info(f"_get_weighted_kp {times[0]=} {times[-1]=}") ap = omni_data.sel(time=slice(times[0] - np.timedelta64(T, 'h'), times[-1]))['ap'].values prehistory = np.column_stack([ap[T - i:ap.shape[0] - i] for i in range(T)]) weight_factors = tau ** np.arange(T) @@ -86,6 +97,7 @@ def estimate_background(tec, patch_shape): def preprocess_interval(data, min_val=0, max_val=100, bg_est_shape=(1, 15, 15)): + logger.info("preprocessing interval") tec = data['tec'].values # throw away outlier values tec[tec > max_val] = np.nan @@ -115,16 +127,21 @@ def fix_boundaries(labels): return fixed -def remove_auroral(data, offset=3): - data['labels'] *= data.mlat < (data['arb'] + offset) +def remove_auroral(data, hemisphere, offset=3): + if hemisphere == 'north': + data['labels'] *= data.mlat < (data['arb'] + offset) + elif hemisphere == 'south': + data['labels'] *= data.mlat > (data['arb'] - offset) + else: + raise ValueError(f"Invalid hemisphere: {hemisphere}, valid = ['north', 'south']") -def postprocess(data, perimeter_th=50, area_th=1, closing_r=0): +def postprocess(data, hemisphere, perimeter_th=50, area_th=1, closing_r=0): if closing_r > 0: selem = morphology.disk(closing_r, dtype=bool)[:, :, None] data['labels'] = np.pad(data['labels'], ((0, 0), (0, 0), (closing_r, closing_r)), 'wrap') data['labels'] = morphology.binary_closing(data['labels'], selem)[:, :, closing_r:-closing_r] - remove_auroral(data) + remove_auroral(data, hemisphere) for t in range(data.time.shape[0]): tmap = data['labels'][t].values labeled = measure.label(tmap, connectivity=2) @@ -183,6 +200,8 @@ def get_optimization_args(data, model_weight_max, rbf_bw, tv_hw, tv_vw, l2_weigh l2 = (model_weight_max - 1) * l2 / l2.max() + 1 l2 *= l2_weight fin_mask = np.isfinite(np.ravel(data['x'].isel(time=i))) + if not fin_mask.any(): + raise Exception("WHY ALL NAN??") args = (cp.Variable(data.mlat.shape[0] * data.mlt.shape[0]), basis[fin_mask, :], np.ravel(data['x'].isel(time=i))[fin_mask], tv, np.ravel(l2)) all_args.append(args) @@ -214,13 +233,13 @@ def run_multiple(args, parallel=True): return np.stack(results, axis=0) -def label_trough_interval(start_date, end_date, params, tec_dir, arb_dir, omni_file): +def label_trough_interval(start_date, end_date, params, hemisphere, tec_dir, arb_dir, omni_file): logger.info(f"labeling trough interval: {start_date=} {end_date=}") - data = _tec.get_tec_data(start_date, end_date, tec_dir).to_dataset(name='tec') + data = _tec.get_tec_data(start_date, end_date, hemisphere, tec_dir).to_dataset(name='tec') preprocess_interval(data, bg_est_shape=params.bg_est_shape) - data['arb'] = _arb.get_arb_data(start_date, end_date, arb_dir) - get_model(data, omni_file) + data['arb'] = _arb.get_arb_data(start_date, end_date, hemisphere, arb_dir) + get_model(data, hemisphere, omni_file) args = get_optimization_args(data, params.model_weight_max, params.rbf_bw, params.tv_hw, params.tv_vw, params.l2_weight, params.tv_weight) logger.info("Running inversion optimization") @@ -237,7 +256,7 @@ def label_trough_interval(start_date, end_date, params, tec_dir, arb_dir, omni_f data['labels'] = data['score'] >= params.threshold # postprocess logger.info("Postprocessing inversion results") - postprocess(data, params.perimeter_th, params.area_th, params.closing_rad) + postprocess(data, hemisphere, params.perimeter_th, params.area_th, params.closing_rad) return data @@ -256,44 +275,47 @@ def label_trough_dataset(start_date, end_date, params=None, tec_dir=None, arb_di Path(output_dir).mkdir(exist_ok=True, parents=True) for year in 
range(start_date.year, end_date.year + 1): - labels = [] - scores = [] - start = datetime(year, 1, 1, 0, 0) - while start.year < year + 1: - end = start + timedelta(days=1) - if start >= end_date or end <= start_date: + for hemisphere in ['north', 'south']: + labels = [] + scores = [] + start = datetime(year, 1, 1, 0, 0) + while start.year < year + 1: + end = start + timedelta(days=1) + if start >= end_date or end <= start_date: + start += timedelta(days=1) + continue + start = max(start_date, start) + end = min(end_date, end) + data = label_trough_interval(start, end, params, hemisphere, tec_dir, arb_dir, omni_file) + if end.year == start.year + 1: + data = data.isel(time=slice(0, -1)) + labels.append(data['labels']) + scores.append(data['score']) start += timedelta(days=1) - continue - start = max(start_date, start) - end = min(end_date, end) - data = label_trough_interval(start, end, params, tec_dir, arb_dir, omni_file) - labels.append(data['labels']) - scores.append(data['score']) - start += timedelta(days=1) - labels = xr.concat(labels, 'time') - scores = xr.concat(scores, 'time') - labels.to_netcdf(Path(output_dir) / f"labels_{year:04d}.nc") - scores.to_netcdf(Path(output_dir) / f"scores_{year:04d}.nc") - - -def get_label_paths(start_date, end_date, processed_dir): + labels = xr.concat(labels, 'time') + scores = xr.concat(scores, 'time') + labels.to_netcdf(Path(output_dir) / f"labels_{hemisphere}_{year:04d}.nc") + scores.to_netcdf(Path(output_dir) / f"scores_{hemisphere}_{year:04d}.nc") + + +def get_label_paths(start_date, end_date, hemisphere, processed_dir): file_dates = np.arange( np.datetime64(start_date, 'Y'), - np.datetime64(end_date, 'Y') + 1, + (np.datetime64(end_date, 's')).astype('datetime64[Y]') + 1, np.timedelta64(1, 'Y') ) file_dates = utils.decompose_datetime64(file_dates) - return [Path(processed_dir) / f"labels_{d[0]:04d}.nc" for d in file_dates] + return [Path(processed_dir) / f"labels_{hemisphere}_{d[0]:04d}.nc" for d in file_dates] -def get_trough_labels(start_date, end_date, labels_dir=None): +def get_trough_labels(start_date, end_date, hemisphere, labels_dir=None): if labels_dir is None: labels_dir = config.processed_labels_dir - data = xr.concat([xr.open_dataarray(file) for file in get_label_paths(start_date, end_date, labels_dir)], 'time') + data = xr.concat([xr.open_dataarray(file) for file in get_label_paths(start_date, end_date, hemisphere, labels_dir)], 'time') return data.sel(time=slice(start_date, end_date)) -def get_data(start_date, end_date, tec_dir=None, omni_file=None, labels_dir=None): +def get_data(start_date, end_date, hemisphere, tec_dir=None, omni_file=None, labels_dir=None): if tec_dir is None: tec_dir = config.processed_tec_dir if omni_file is None: @@ -301,6 +323,6 @@ def get_data(start_date, end_date, tec_dir=None, omni_file=None, labels_dir=None if labels_dir is None: labels_dir = config.processed_labels_dir data = xr.open_dataset(omni_file).sel(time=slice(start_date, end_date)) - data['tec'] = _tec.get_tec_data(start_date, end_date, tec_dir) - data['labels'] = get_trough_labels(start_date, end_date, labels_dir) + data['tec'] = _tec.get_tec_data(start_date, end_date, hemisphere, tec_dir) + data['labels'] = get_trough_labels(start_date, end_date, hemisphere, labels_dir) return data diff --git a/trough/scripts.py b/trough/scripts.py index d9bb686..18f3d72 100644 --- a/trough/scripts.py +++ b/trough/scripts.py @@ -74,28 +74,6 @@ def label_trough(start_date, end_date): _trough.label_trough_dataset(start_date, end_date) -def 
_tec_interval_check(start, end): - try: - data_check = _tec.get_tec_data(start, end) - if not data_check.isnull().all(dim=['mlt', 'mlat']).any().item(): - logger.info(f"tec data already processed {start=}, {end=}") - return False - except Exception as e: - logger.info(f"error getting data {start=}, {end=}: {e}, reprocessing") - return True - - -def _arb_interval_check(start, end): - try: - data_check = _arb.get_arb_data(start, end) - if not data_check.isnull().all(dim=['mlt']).any().item(): - logger.info(f"arb data already processed {start=}, {end=}") - return False - except Exception as e: - logger.info(f"error getting data {start=}, {end=}: {e}, reprocessing") - return True - - def full_run(start_date, end_date): logger.info("running 'full_run'") for year in range(start_date.year, end_date.year + 1): @@ -103,13 +81,10 @@ def full_run(start_date, end_date): end = min(end_date, datetime(year + 1, 1, 1)) if end - start <= timedelta(hours=1): continue - - if _tec_interval_check(start, end): - download_tec(start, end) - process_tec(start, end) - if _arb_interval_check(start, end): - download_arb(start, end) - process_arb(start, end) + download_tec(start, end) + process_tec(start, end) + download_arb(start, end) + process_arb(start, end) download_omni(start_date, end_date) process_omni(start_date, end_date) diff --git a/trough/utils.py b/trough/utils.py index 1fbfd59..d9a5040 100644 --- a/trough/utils.py +++ b/trough/utils.py @@ -1,11 +1,15 @@ import numpy as np import datetime import warnings +import logging try: import h5py from skimage.util import view_as_windows -except ImportError as e: - warnings.warn(f"Packages required for recreating dataset not installed: {e}") +except ImportError as imp_err: + warnings.warn(f"Packages required for recreating dataset not installed: {imp_err}") + + +logger = logging.getLogger(__name__) def datetime64_to_datetime(dt64): @@ -126,3 +130,26 @@ def write_h5(fn, **kwargs): with h5py.File(fn, 'w') as f: for key, value in kwargs.items(): f.create_dataset(key, data=value) + + +def get_data_checker(data_getter): + + def check(start, end, dt, hemisphere, processed_file): + times = np.arange(np.datetime64(start), np.datetime64(end), dt).astype('datetime64[s]') + if processed_file.exists(): + logger.info(f"processed file already exists {processed_file=}, checking...") + try: + data_check = data_getter(start, end, hemisphere, processed_file.parent) + data_check = data_check.sel(time=times) + has_missing_data = data_check.isnull().all(axis=[i for i in range(1, data_check.ndim)]).any() + if not has_missing_data: + logger.info(f"downloaded data already processed {processed_file=}, checking...") + return False + except KeyError: + logger.info(f"processed file doesn't have the requested data") + except Exception as e: + logger.info(f"error reading processed file {processed_file=}: {e}, removing and reprocessing") + processed_file.unlink() + return True + + return check
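
The new hemisphere argument threads through the public entry point as well; a minimal usage sketch, taken from the updated `test/subtest.py` and the README usage steps, assuming the full pipeline (`python -m trough config.json`) has already produced processed data for the requested interval:

```python
from datetime import datetime

import trough

start_date = datetime(2020, 9, 8, 9)
end_date = datetime(2020, 9, 9, 12)

# hemisphere ('north' or 'south') is now a required positional argument
data = trough.get_data(start_date, end_date, 'north')

print(data['tec'].shape)     # (time, mlat, mlt)
print(data['labels'].shape)  # boolean trough labels on the same grid
print(data['kp'].shape)      # Kp from the processed OMNI file
```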
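
Another recurring change above is the `+ dt` added to nearly every `np.arange` over times: the processed time axis is now end-inclusive. A small sketch (dates borrowed from `test_get_tec_troughs`) showing why the tests now expect `n_hours + 1` samples:

```python
import numpy as np

start = np.datetime64('2015-10-07T06:00:00')
end = np.datetime64('2015-10-07T18:00:00')
dt = np.timedelta64(1, 'h')

exclusive = np.arange(start, end, dt)        # old behaviour: 12 samples, stops at 17:00
inclusive = np.arange(start, end + dt, dt)   # new behaviour: 13 samples, includes 18:00

assert exclusive.shape[0] == 12
assert inclusive.shape[0] == 13
assert inclusive[-1] == end
```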
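
The southern-hemisphere support in `calculate_bins` / `process_interval` folds southern pixels onto the shared positive-mlat bin grid and restores the sign on the output coordinate. A simplified standalone sketch of that convention, using the default config bin edges; the sample pixel values are made up:

```python
import numpy as np
from scipy.stats import binned_statistic_2d

mlat_bins = np.arange(29.5, 90, 1.0)                     # |mlat| edges shared by both hemispheres
mlt_bins = np.arange(-12, 12 + 24 / 360, 2 * 24 / 360)   # 2-degree-longitude MLT bins
mlat_vals = (mlat_bins[:-1] + mlat_bins[1:]) / 2

hemisphere = 'south'
h = 1 if hemisphere == 'north' else -1

# hypothetical southern-hemisphere TEC pixels
mlat = np.array([-45.2, -60.7, -75.1])
mlt = np.array([-3.0, 0.0, 3.0])
tec = np.array([12.0, 4.0, 8.0])

# fold the south onto the positive-mlat grid and drop anything from the other hemisphere
mask = np.isfinite(tec) & (h * mlat >= 0)
stat, _, _, _ = binned_statistic_2d(h * mlat[mask], mlt[mask], tec[mask],
                                    statistic='mean', bins=[mlat_bins, mlt_bins])

out_mlat = h * mlat_vals   # the stored mlat coordinate keeps the hemisphere's sign
```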
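
For the trough-latitude model, the only functional change is the hemisphere-dependent longitude term; pulled out of `trough/_trough.py` as a standalone function (renamed here for the sketch) so the Deminov (2017) formula is easy to read:

```python
import numpy as np

def model_trough_mlat(mlt, glon, kp, hemisphere):
    """mlt in hours, glon in degrees (geographic), kp is the weighted Kp index."""
    phi_t = (3.16 - 5.6 * np.cos(np.deg2rad(15 * (mlt - 2.4)))
             + 1.4 * np.cos(np.deg2rad(15 * (2 * mlt - .8))))
    if hemisphere == 'north':
        phi_lon = .85 * np.cos(np.deg2rad(glon + 63)) - .52 * np.cos(np.deg2rad(2 * glon + 5))
    elif hemisphere == 'south':
        phi_lon = 1.5 * np.cos(np.deg2rad(glon - 119))
    else:
        raise ValueError(f"Invalid hemisphere: {hemisphere}, valid = ['north', 'south']")
    return 65.5 - 2.4 * kp + phi_t + phi_lon * np.exp(-.3 * kp)

# example call with scalar inputs
print(model_trough_mlat(mlt=0.0, glon=0.0, kp=2.0, hemisphere='north'))
```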