about summary refs log tree commit diff stats
path: root/discord
diff options
context:
space:
mode:
authorNoah <mounderfod@gmail.com>2022-07-08 21:22:26 +0100
committerNoah <mounderfod@gmail.com>2022-07-08 21:23:12 +0100
commitddc5a3d3fed034a04bedb24cadc5f592daff4a27 (patch)
tree9683b5b44825199588acbef67e2741d9f2c6a937 /discord
parent74d5b1199adc085d39417b8e1233a8730394d42e (diff)
parent99da423d99378eaf446031451f9f6de76010d7c0 (diff)
downloaddiscobra-ddc5a3d3fed034a04bedb24cadc5f592daff4a27.tar.gz
chore: Merge branch 'master' of https://github.com/mounderfod/discobra
 Conflicts:
	discord/client.py
Diffstat (limited to 'discord')
-rw-r--r--discord/client.py29
1 files changed, 19 insertions, 10 deletions
diff --git a/discord/client.py b/discord/client.py
index 73fe0e3..c500d1c 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -2,27 +2,23 @@ import asyncio
 import json
 import sys
 import threading
-
 import websockets
-
+from typing import Coroutine
 from discord.intents import Intents, gen_number
 
-loop = asyncio.get_event_loop()
-
-
 class Client:
-    def __init__(self, token: str, intents: list[Intents]):
+    def __init__(self, intents: list[Intents]):
         self.gateway = None
-        code = gen_number(intents)
-        asyncio.run(self.connect(token, code))
+        self.loop = asyncio.get_event_loop()
+        self.code = gen_number(intents)
 
     async def connect(self, token: str, intent_code: int):
         async with websockets.connect("wss://gateway.discord.gg/?v=10&encoding=json") as gateway:
             hello = await gateway.recv()
             self.gateway = gateway
-            threading.Thread(target=loop.run_forever).start()
+            threading.Thread(target=self.loop.run_forever).start()
             heartbeat = asyncio.run_coroutine_threadsafe(
-                self.heartbeat(gateway, json.loads(hello)['d']['heartbeat_interval']), loop)
+                self.heartbeat(gateway, json.loads(hello)['d']['heartbeat_interval']), self.loop)
             identify = {
                 "op": 2,
                 "d": {
@@ -37,6 +33,8 @@ class Client:
             }
             await gateway.send(json.dumps(identify))
             ready = await gateway.recv()
+            if (hasattr(self, 'on_ready')):
+                await getattr(self, 'on_ready')()
 
     async def heartbeat(self, gateway: websockets.WebSocketClientProtocol, interval: int):
         while True:
@@ -47,3 +45,14 @@ class Client:
             }
             await gateway.send(json.dumps(heartbeat))
             ack = await gateway.recv()
+
+    def event(self, coro: Coroutine, /) -> Coroutine:
+        if not asyncio.iscoroutinefunction(coro):
+            raise TypeError('event registered must be a coroutine function')
+
+        setattr(self, coro.__name__, coro)
+        return coro
+
+    def run(self, token: str):
+        self.token = token
+        asyncio.run(self.connect(self.token, self.code))