about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--discord/client.py28
-rw-r--r--discord/utils/event_emitter.py22
2 files changed, 33 insertions, 17 deletions
diff --git a/discord/client.py b/discord/client.py
index 4c6be14..0ac042e 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -2,7 +2,7 @@ import asyncio
 import json
 import sys
 import threading
-from typing import Coroutine
+from typing import Optional, Coroutine, Any, Callable
 import websockets
 
 from .utils import EventEmitter
@@ -38,8 +38,30 @@ class Client:
             }
             await gateway.send(json.dumps(identify))
             ready = await gateway.recv()
-            self.event_emitter.emit('on_ready', False)
+            self.event_emitter.emit('on_ready')
             self.user = User(json.loads(ready)['d']['user'])
+    
+    async def send(self, data: dict):
+        """
+        Send data to the gateway.
+        """
+        await self.gateway.send(json.dumps(data))
+    
+    async def recv(self, msg):
+        """
+        Receive data from the gateway.
+        """
+        pass
+
+    async def close(self):
+        """
+        Close the client.
+        """
+        self.loop.stop()
+        await self.gateway.close()
+
+    async def poll_events(self):
+        pass
 
     async def heartbeat(self, gateway: websockets.WebSocketClientProtocol, interval: int):
         while True:
@@ -51,7 +73,7 @@ class Client:
             await gateway.send(json.dumps(heartbeat))
             ack = await gateway.recv()
 
-    def event(self, coro: Coroutine, /) -> Coroutine:
+    def event(self, coro: Optional[Callable[..., Coroutine[Any, Any, Any]]]=None, /) -> Optional[Callable[..., Coroutine[Any, Any, Any]]]:
         """
         Registers a coroutine to be called when an event is emitted.
         """
diff --git a/discord/utils/event_emitter.py b/discord/utils/event_emitter.py
index 77fff72..08b6060 100644
--- a/discord/utils/event_emitter.py
+++ b/discord/utils/event_emitter.py
@@ -1,29 +1,23 @@
 import asyncio
-from types import NoneType
-from typing import Coroutine
+from typing import Optional, Coroutine, Any, Callable, Dict
 
 class EventEmitter():
-    def __init__(self):
-        self.listeners = {}
+    def __init__(self, loop: Optional[asyncio.AbstractEventLoop]=None):
+        self.listeners: Dict[str, Optional[Callable[..., Coroutine[Any, Any, Any]]]] = {}
+        self.loop = loop if loop else asyncio.get_event_loop()
 
-    def add_listener(self, event_name: str, func: Coroutine):
+    def add_listener(self, event_name: str, func: Optional[Callable[..., Coroutine[Any, Any, Any]]]=None):
         if not self.listeners.get(event_name, None):
             self.listeners[event_name] = {func}
         else:
             self.listeners[event_name].add(func)
 
-    def remove_listener(self, event_name: str, func: Coroutine):
+    def remove_listener(self, event_name: str, func: Optional[Callable[..., Coroutine[Any, Any, Any]]]=None):
         self.listeners[event_name].remove(func)
         if len(self.listeners[event_name]) == 0:
             del self.listeners[event_name]
 
-    def emit(self, event_name: str, args_required=False, *args, **kwargs):
+    def emit(self, event_name: str, *args: Any, **kwargs: Any) -> None:
         listeners = self.listeners.get(event_name, [])
         for func in listeners:
-            if args_required:
-                if len(args) == 0:
-                    raise TypeError('event registered must have arguments')
-                else:
-                    asyncio.create_task(func(*args, **kwargs))
-            else:
-                asyncio.create_task(func(*args, **kwargs))
+            asyncio.run_coroutine_threadsafe(func(*args, **kwargs), self.loop)