From 01629ca407c04ce76301f1a170cafab7c352ab47 Mon Sep 17 00:00:00 2001 From: boke0 Date: Wed, 10 Mar 2021 17:12:59 +0900 Subject: [PATCH] [update] --- mitama/db/event.py | 64 +++++++++++++++++++++++++++++++-------------- mitama/db/model.py | 21 +++++++-------- tests/test_event.py | 41 +++++++++++++++++++++++++++++ 3 files changed, 96 insertions(+), 30 deletions(-) create mode 100644 tests/test_event.py diff --git a/mitama/db/event.py b/mitama/db/event.py index 2a30bf5..9f167da 100644 --- a/mitama/db/event.py +++ b/mitama/db/event.py @@ -1,32 +1,58 @@ -class Event: - def __init__(self, doc=None): +class EventManager: + def __init__(self, events=[], doc=None): self.__doc__ = doc + self.events = {} + for event in events: + self.events[event] = Event() def __get__(self, obj, objtype=None): if obj is None: - return Self - return EventHandler(self, obj) + return self + return EventManagerInstance(self, obj) def __set__(self, obj, objtype=None): pass -class EventHandler: - def __init__(self, event, obj): - self.event = event - self.obj = obj - def _getfunctionlist(self): - try: - eventhandler = self.obj.__eventhandler__ - except AttributeError: - eventhandler = self.obj.__eventhandler__ = {} - return eventhandler.setdefault(self.event, []) + def __getitem__(self, key): + return self.events[key] + + def __setitem__(self, key, value): + self.events[key] = value + + def listen(self, event): + self.events[event] = Event() + + +class EventManagerInstance: + def __init__(self, event_manager, obj): + self.event_manager = event_manager + self.object = obj + + def __getitem__(self, key): + return self.event_manager.events[key].handler(self.object) + + +class Event: + def __init__(self): + self._funcs = [] + def __iadd__(self, func): - self._getfunctionlist().append(func) + self._funcs.append(func) return self + def __isub__(self, func): - self._getfunctionlist().remove(func) + self._funcs.remove(func) return self - def __call__(self,earg=None): - for func in self._getfunctionlist(): - func(self.obj, earg) + def handler(self, obj): + return EventHandler(self, obj) + + +class EventHandler: + def __init__(self, event, obj): + self.event = event + self.object = obj + + def __call__(self, **kwargs): + for f in self.event._funcs: + f(self.object, **kwargs) diff --git a/mitama/db/model.py b/mitama/db/model.py index d87f148..e9514aa 100755 --- a/mitama/db/model.py +++ b/mitama/db/model.py @@ -17,7 +17,7 @@ from mitama._extra import _classproperty, tosnake from .types import Column, Integer, String -from .event import Event +from .event import EventManager def UUID(prefix=None): @@ -32,9 +32,11 @@ def genUUID(): class Model: prefix = None _id = Column(String(64), default=UUID(), primary_key=True, nullable=False) - create = Event() - update = Event() - delete = Event() + event = EventManager([ + "create", + "update", + "delete" + ]) @classmethod def attribute_names(cls): @@ -84,31 +86,28 @@ def create(self): self.query.session.add(self) self.query.session.commit() try: - self.on("create")() + self.event["create"]() except Exception: pass def update(self): self.query.session.commit() try: - self.on("update")() + self.event["update"]() except Exception: pass def delete(self): try: - self.on("delete")() + self.event["delete"]() except Exception: pass self.query.session.delete(self) self.query.session.commit() - def on(self, evt): - return getattr(self, evt) - @classmethod def listen(cls, evt): - setattr(cls, evt, Event()) + cls.event.listen(evt) @classmethod def list(cls, *args): diff --git a/tests/test_event.py b/tests/test_event.py new file mode 100644 index 0000000..17b0619 --- /dev/null +++ b/tests/test_event.py @@ -0,0 +1,41 @@ +import unittest + +from mitama.db import DatabaseManager, BaseDatabase +from mitama.db.types import Column, String + +DatabaseManager.test() + + +class Database(BaseDatabase): + pass + + +db = Database() + + +class ModelC(db.Model): + name = Column(String) + + +def addhoge(modelc): + modelc.name = "huga" + modelc.update() + + +ModelC.listen("hoge") +ModelC.event["hoge"] += addhoge + + +db.create_all() + + +class TestEvent(unittest.TestCase): + def test_event(self): + DatabaseManager.start_session() + hoge = ModelC() + hoge.name = "hoge" + hoge.create() + hoge.event["hoge"]() + self.assertEqual(hoge.name, "huga") + DatabaseManager.close_session() + pass