Skip to content

Commit

Permalink
Add type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
kernitus committed Oct 29, 2023
1 parent bbc6201 commit 3be815a
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 40 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ jobs:
- name: Install package and dependencies
run: |
python -m pip install --upgrade pip
pip install flake8
pip install flake8 mypy types-python-dateutil
pip install .
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Type checking with mypy
run: mypy --ignore-missing-imports --strict .
- name: Test with unittest
run: python -m unittest discover -s tests
2 changes: 1 addition & 1 deletion LICENSE → LICENCE
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
MIT License
MIT Licence

Copyright (c) 2020 kernitus

Expand Down
8 changes: 6 additions & 2 deletions beetsplug/date_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import datetime
from typing import Optional

from dateutil import parser


Expand All @@ -9,7 +11,8 @@ class DateWrapper(datetime.datetime):
with the month and day being optional.
"""

def __new__(cls, y: int = None, m: int = None, d: int = None, iso_string: str = None):
def __new__(cls, y: Optional[int] = None, m: Optional[int] = None, d: Optional[int] = None,
iso_string: Optional[str] = None):
"""
Create a new datetime object using a convenience wrapper.
Must specify at least one of either year or iso_string.
Expand Down Expand Up @@ -38,7 +41,8 @@ def today(cls):
today = datetime.date.today()
return DateWrapper(today.year, today.month, today.day)

def __init__(self, y=None, m=None, d=None, iso_string=None):
def __init__(self, y: Optional[int] = None, m: Optional[int] = None, d: Optional[int] = None,
iso_string: Optional[str] = None):
if y is not None:
self.y = min(max(y, datetime.MINYEAR), datetime.MAXYEAR)
self.m = m if (m is None or 0 < m <= 12) else 1
Expand Down
79 changes: 43 additions & 36 deletions beetsplug/oldestdate.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Optional, Any
import mediafile
import musicbrainzngs
from beets import ui, config
from beets.autotag import hooks, TrackInfo
from beets.importer import action
from beets.library import Item
from beets.importer import action, ImportTask, ImportSession
from beets.library import Item, Library
from beets.plugins import BeetsPlugin

from .date_wrapper import DateWrapper
Expand All @@ -14,9 +15,12 @@
"https://github.com/kernitus/beets-oldestdate"
)

# Type alias
Recording = dict[str, Any]

# Extract first valid work_id from recording
def _get_work_id_from_recording(recording):

def _get_work_id_from_recording(recording: Recording) -> Optional[str]:
"""Extract first valid work_id from recording"""
work_id = None

if 'work-relation-list' in recording:
Expand All @@ -30,8 +34,8 @@ def _get_work_id_from_recording(recording):
return work_id


# Returns whether this recording contains at least one of the specified artists
def _contains_artist(recording, artist_ids):
def _contains_artist(recording: Recording, artist_ids: list[str]) -> bool:
"""Returns whether this recording contains at least one of the specified artists"""
artist_found = False
if 'artist-credit' in recording:
for artist in recording['artist-credit']:
Expand All @@ -43,8 +47,8 @@ def _contains_artist(recording, artist_ids):
return artist_found


# Extract artist ids from a recording
def _get_artist_ids_from_recording(recording):
def _get_artist_ids_from_recording(recording: Recording) -> list[str]:
"""Extract artist ids from a recording"""
ids = []

if 'artist-credit' in recording:
Expand All @@ -56,8 +60,8 @@ def _get_artist_ids_from_recording(recording):
return ids


# Returns whether given fetched recording is a cover of a work
def _is_cover(recording):
def _is_cover(recording: Recording) -> bool:
"""Returns whether given fetched recording is a cover of a work"""
if 'work-relation-list' in recording:
for work in recording['work-relation-list']:
if 'attribute-list' in work:
Expand All @@ -66,14 +70,14 @@ def _is_cover(recording):
return False


# Fetch work, including recording relations
def _fetch_work(work_id):
def _fetch_work(work_id: str) -> Recording:
"""Fetch work, including recording relations"""
return musicbrainzngs.get_work_by_id(work_id, ['recording-rels'])['work']


class OldestDatePlugin(BeetsPlugin):
_importing = False
_recordings_cache = dict()
_importing: bool = False
_recordings_cache: dict[str, Recording] = dict()

def __init__(self):
super(OldestDatePlugin, self).__init__()
Expand Down Expand Up @@ -126,12 +130,12 @@ def commands(self):
recording_date_command.func = self._command_func
return [recording_date_command]

# Fetch the recording associated with each candidate
def _import_trackinfo(self, info):
def _import_trackinfo(self, info: TrackInfo) -> None:
"""Fetch the recording associated with each candidate"""
if 'track_id' in info:
self._fetch_recording(info.track_id)

def track_distance(self, _, info: TrackInfo):
def track_distance(self, _: Item, info: TrackInfo) -> hooks.Distance:
dist = hooks.Distance()
if info.data_source != 'MusicBrainz':
self._log.debug('Skipping track with non MusicBrainz data source {0.artist} - {0.title}', info)
Expand All @@ -141,10 +145,10 @@ def track_distance(self, _, info: TrackInfo):

return dist

def _import_task_created(self, task, session):
def _import_task_created(self, task: ImportTask, _: ImportSession) -> None:
task.item.mb_trackid = None

def _import_task_choice(self, task, session):
def _import_task_choice(self, task: ImportTask, _: ImportSession) -> None:
match = task.match
if not match:
return
Expand Down Expand Up @@ -174,24 +178,24 @@ def _import_task_choice(self, task, session):
task.choice_flag = action.SKIP
return

# Return whether the recording has a work id
def _has_work_id(self, recording_id):
def _has_work_id(self, recording_id: str) -> bool:
"""Return whether the recording has a work id"""
recording = self._get_recording(recording_id)
work_id = _get_work_id_from_recording(recording)
return work_id is not None

# This queries the local database, not the files.
def _command_func(self, lib, session, args):
def _command_func(self, lib: Library, _: ImportSession, args: list[str]) -> None:
"""This queries the local database, not the files."""
for item in lib.items(args):
self._process_file(item)

def _on_import(self, session, task):
def _on_import(self, _: ImportSession, task: ImportTask) -> None:
if self.config['auto']:
self._importing = True
for item in task.imported_items():
self._process_file(item)

def _process_file(self, item: Item):
def _process_file(self, item: Item) -> None:
if not item.mb_trackid or item.data_source != 'MusicBrainz':
self._log.info('Skipping track with no mb_trackid: {0.artist} - {0.title}', item)
return
Expand Down Expand Up @@ -234,19 +238,20 @@ def _process_file(self, item: Item):
if not self._importing:
item.write()

# Fetch and cache recording from MusicBrainz, including releases and work relations
def _fetch_recording(self, recording_id):
def _fetch_recording(self, recording_id: str) -> Recording:
"""Fetch and cache recording from MusicBrainz, including releases and work relations"""
recording = musicbrainzngs.get_recording_by_id(recording_id, ['artists', 'releases', 'work-rels'])['recording']
self._recordings_cache[recording_id] = recording
return recording

# Get recording from cache or MusicBrainz
def _get_recording(self, recording_id):
def _get_recording(self, recording_id: str) -> Recording:
"""Get recording from cache or MusicBrainz"""
return self._recordings_cache[
recording_id] if recording_id in self._recordings_cache else self._fetch_recording(recording_id)

# Get oldest date from a recording
def _extract_oldest_recording_date(self, recordings, starting_date, is_cover, approach):
def _extract_oldest_recording_date(self, recordings: list[Recording], starting_date: DateWrapper,
is_cover: bool, approach: str) -> DateWrapper:
"""Get oldest date from a recording"""
oldest_date = starting_date

for rec in recordings:
Expand Down Expand Up @@ -279,8 +284,9 @@ def _extract_oldest_recording_date(self, recordings, starting_date, is_cover, ap

return oldest_date

# Get oldest date from a release
def _extract_oldest_release_date(self, recordings, starting_date, is_cover, artist_ids):
def _extract_oldest_release_date(self, recordings: list[Recording], starting_date: DateWrapper,
is_cover: bool, artist_ids: list[str]) -> DateWrapper:
"""Get oldest date from a release"""
oldest_date = starting_date
release_types = self.config['release_types'].get()

Expand Down Expand Up @@ -328,8 +334,9 @@ def _extract_oldest_release_date(self, recordings, starting_date, is_cover, arti

return oldest_date

# Iterates through a list of recordings and returns oldest date
def _iterate_dates(self, recordings, starting_date, is_cover, artist_ids):
def _iterate_dates(self, recordings: list[Recording], starting_date: DateWrapper,
is_cover: bool, artist_ids: list[str]) -> Optional[DateWrapper]:
"""Iterates through a list of recordings and returns oldest date"""
approach = self.config['approach'].get()
oldest_date = starting_date

Expand All @@ -343,7 +350,7 @@ def _iterate_dates(self, recordings, starting_date, is_cover, artist_ids):

return None if oldest_date == DateWrapper.today() else oldest_date

def _get_oldest_date(self, recording_id, item_date):
def _get_oldest_date(self, recording_id: str, item_date: Optional[DateWrapper]) -> Optional[DateWrapper]:
recording = self._get_recording(recording_id)
is_cover = _is_cover(recording)
work_id = _get_work_id_from_recording(recording)
Expand Down

0 comments on commit 3be815a

Please sign in to comment.