about summary refs log tree commit diff stats
path: root/discord/utils
diff options
context:
space:
mode:
Diffstat (limited to 'discord/utils')
-rw-r--r--discord/utils/event_emitter.py22
1 files changed, 8 insertions, 14 deletions
diff --git a/discord/utils/event_emitter.py b/discord/utils/event_emitter.py
index 77fff72..08b6060 100644
--- a/discord/utils/event_emitter.py
+++ b/discord/utils/event_emitter.py
@@ -1,29 +1,23 @@
 import asyncio
-from types import NoneType
-from typing import Coroutine
+from typing import Optional, Coroutine, Any, Callable, Dict
 
 class EventEmitter():
-    def __init__(self):
-        self.listeners = {}
+    def __init__(self, loop: Optional[asyncio.AbstractEventLoop]=None):
+        self.listeners: Dict[str, Optional[Callable[..., Coroutine[Any, Any, Any]]]] = {}
+        self.loop = loop if loop else asyncio.get_event_loop()
 
-    def add_listener(self, event_name: str, func: Coroutine):
+    def add_listener(self, event_name: str, func: Optional[Callable[..., Coroutine[Any, Any, Any]]]=None):
         if not self.listeners.get(event_name, None):
             self.listeners[event_name] = {func}
         else:
             self.listeners[event_name].add(func)
 
-    def remove_listener(self, event_name: str, func: Coroutine):
+    def remove_listener(self, event_name: str, func: Optional[Callable[..., Coroutine[Any, Any, Any]]]=None):
         self.listeners[event_name].remove(func)
         if len(self.listeners[event_name]) == 0:
             del self.listeners[event_name]
 
-    def emit(self, event_name: str, args_required=False, *args, **kwargs):
+    def emit(self, event_name: str, *args: Any, **kwargs: Any) -> None:
         listeners = self.listeners.get(event_name, [])
         for func in listeners:
-            if args_required:
-                if len(args) == 0:
-                    raise TypeError('event registered must have arguments')
-                else:
-                    asyncio.create_task(func(*args, **kwargs))
-            else:
-                asyncio.create_task(func(*args, **kwargs))
+            asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self.loop)