about summary refs log tree commit diff stats
path: root/discord
diff options
context:
space:
mode:
authormjk134 <57556877+mjk134@users.noreply.github.com>2022-07-09 14:05:58 +0000
committerGitHub <noreply@github.com>2022-07-09 14:05:58 +0000
commit56cff43ab5e5c6f0f0bd8ce714f830cb6e129327 (patch)
treef9aa81159d01ace434a7f0677904164bc0be9b1b /discord
parenta72d0da957af970c1608061d5dcfbf4df3975ffd (diff)
downloaddiscobra-56cff43ab5e5c6f0f0bd8ce714f830cb6e129327.tar.gz
feat: Added poll event loop
Diffstat (limited to 'discord')
-rw-r--r--discord/client.py120
1 files changed, 88 insertions, 32 deletions
diff --git a/discord/client.py b/discord/client.py
index 0ac042e..d272130 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -1,45 +1,48 @@
 import asyncio
+from enum import IntEnum
 import json
 import sys
 import threading
 from typing import Optional, Coroutine, Any, Callable
+import zlib
 import websockets
 
 from .utils import EventEmitter
 from .intents import Intents, gen_number
 from .user import User
 
-
+class GatewayEvents(IntEnum):
+    DISPATCH           = 0
+    HEARTBEAT          = 1
+    IDENTIFY           = 2
+    PRESENCE           = 3
+    VOICE_STATE        = 4
+    VOICE_PING         = 5
+    RESUME             = 6
+    RECONNECT          = 7
+    REQUEST_MEMBERS    = 8
+    INVALIDATE_SESSION = 9
+    HELLO              = 10
+    HEARTBEAT_ACK      = 11
+    GUILD_SYNC         = 12
 class Client:
     def __init__(self, intents: list[Intents]):
         self.gateway = None
         self.loop = asyncio.get_event_loop()
-        self.code = gen_number(intents)
+        self.code: int = gen_number(intents)
         self.event_emitter = EventEmitter()
+        self.buffer = bytearray()
+        self.inflator = zlib.decompressobj()
+        self.heartbeat_interval: int = None
+        self.token: str = None
+        self.ready: bool = False
 
     async def connect(self, token: str, intent_code: int):
         async with websockets.connect("wss://gateway.discord.gg/?v=10&encoding=json") as gateway:
-            hello = await gateway.recv()
             self.gateway = gateway
             threading.Thread(target=self.loop.run_forever).start()
-            heartbeat = asyncio.run_coroutine_threadsafe(
-                self.heartbeat(gateway, json.loads(hello)['d']['heartbeat_interval']), self.loop)
-            identify = {
-                "op": 2,
-                "d": {
-                    "token": token,
-                    "intents": intent_code,
-                    "properties": {
-                        "os": sys.platform,
-                        "browser": "discobra",
-                        "device": "discobra"
-                    }
-                }
-            }
-            await gateway.send(json.dumps(identify))
-            ready = await gateway.recv()
-            self.event_emitter.emit('on_ready')
-            self.user = User(json.loads(ready)['d']['user'])
+            while True:
+                await self.poll_event()
     
     async def send(self, data: dict):
         """
@@ -51,7 +54,41 @@ class Client:
         """
         Receive data from the gateway.
         """
-        pass
+        if type(msg) is bytes:
+            self.buffer.extend(msg)
+            if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff':
+                return
+
+            msg = self.inflator.decompress(self.buffer)
+            msg.decode('utf-8')
+            self.buffer = bytearray()
+        msg = json.loads(msg)
+        opcode = msg['op']
+        data = msg['d']
+        sequence = msg['s']
+
+        if opcode != GatewayEvents.DISPATCH.value:
+            if opcode == GatewayEvents.RECONNECT.value:
+                await self.gateway.close()
+
+            if opcode == GatewayEvents.HELLO.value:
+                self.heartbeat_interval = data['heartbeat_interval']
+                asyncio.run_coroutine_threadsafe(self.heartbeat(self.heartbeat_interval), self.loop)
+                return await self.identify()
+
+            if opcode == GatewayEvents.HEARTBEAT_ACK.value:
+                return await self.heartbeat(self.heartbeat_interval)
+            
+            if opcode == GatewayEvents.HEARTBEAT.value:
+                return await self.heartbeat(self.heartbeat_interval)
+
+        event = msg['t']
+
+        if event == 'READY':
+            self.user = User(data['user'])
+
+        self.event_emitter.emit('on_' + event.lower())
+
 
     async def close(self):
         """
@@ -60,18 +97,36 @@ class Client:
         self.loop.stop()
         await self.gateway.close()
 
-    async def poll_events(self):
-        pass
+    async def poll_event(self):
+        msg = await self.gateway.recv()
+        await self.recv(msg)
+    
+
+    async def heartbeat(self, interval: int):
+        await asyncio.sleep(interval / 1000)
+        heartbeat = {
+            "op": 1,
+            "d": None
+        }
+        await self.gateway.send(json.dumps(heartbeat))
 
-    async def heartbeat(self, gateway: websockets.WebSocketClientProtocol, interval: int):
-        while True:
-            await asyncio.sleep(interval / 1000)
-            heartbeat = {
-                "op": 1,
-                "d": None
+    async def identify(self):
+        """
+        Identify the client.
+        """
+        identify = {
+            "op": GatewayEvents.IDENTIFY,
+            "d": {
+                "token": self.token,
+                "intents": self.code,
+                "properties": {
+                    "os": sys.platform,
+                    "browser": "discobra",
+                    "device": "discobra"
+                }
             }
-            await gateway.send(json.dumps(heartbeat))
-            ack = await gateway.recv()
+        }
+        await self.gateway.send(json.dumps(identify))
 
     def event(self, coro: Optional[Callable[..., Coroutine[Any, Any, Any]]]=None, /) -> Optional[Callable[..., Coroutine[Any, Any, Any]]]:
         """
@@ -86,4 +141,5 @@ class Client:
         """
         Run the client.
         """
+        self.token = token
         asyncio.run(self.connect(token, self.code))