about summary refs log tree commit diff stats
path: root/discord
diff options
context:
space:
mode:
authorNoah <mounderfod@gmail.com>2022-07-09 18:59:36 +0100
committerNoah <mounderfod@gmail.com>2022-07-09 18:59:36 +0100
commit57fff05c0e2058855946ffa4e5c02872c534ee65 (patch)
tree41faac71321b81f0ac34083f20e57ce461ba46f6 /discord
parent5058a47b28fa1f0a5982a65c2a65955ea158a0e4 (diff)
parent55fe04a1d73625600c87e41fe45e5a713bda0aba (diff)
downloaddiscobra-57fff05c0e2058855946ffa4e5c02872c534ee65.tar.gz
Merge branch 'master' of https://github.com/mounderfod/discobra
 Conflicts:
	discord/client.py
Diffstat (limited to 'discord')
-rw-r--r--discord/client.py131
-rw-r--r--discord/utils/rest.py41
2 files changed, 109 insertions, 63 deletions
diff --git a/discord/client.py b/discord/client.py
index 1100c83..b9abe1e 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -1,16 +1,33 @@
 import asyncio
+from enum import IntEnum
 import json
 import sys
 import threading
 import warnings
 from typing import Optional, Coroutine, Any, Callable
+import zlib
+import aiohttp
 import websockets
 
 from .utils import EventEmitter
-from .utils.rest import get
+from .utils.rest import RESTClient
 from .intents import Intents, get_number
 from .user import User
 
+class GatewayEvents(IntEnum):
+    DISPATCH           = 0
+    HEARTBEAT          = 1
+    IDENTIFY           = 2
+    PRESENCE           = 3
+    VOICE_STATE        = 4
+    VOICE_PING         = 5
+    RESUME             = 6
+    RECONNECT          = 7
+    REQUEST_MEMBERS    = 8
+    INVALIDATE_SESSION = 9
+    HELLO              = 10
+    HEARTBEAT_ACK      = 11
+    GUILD_SYNC         = 12
 
 class Client:
     """
@@ -22,7 +39,7 @@ class Client:
     @property
     async def user(self):
         """The `discord.user.User` associated with the client."""
-        data = await get(self._token, '/users/@me')
+        data = await self.rest_client.get(self._token, '/users/@me')
         return User(data)
 
     def __init__(self, intents: list[Intents]):
@@ -33,11 +50,19 @@ class Client:
                           "Discord developer portal.")
         if Intents.GUILD_MEMBERS in intents or Intents.GUILD_PRESENCES in intents:
             warnings.warn("You are using one or more privileged intent (Guild Members and/or Guild Presences). You "
-                          "must enable them in the Discord developer portal.") 
-        self.code = get_number(intents)
+                          "must enable them in the Discord developer portal.")
+        self.code: int = get_number(intents)
         self.event_emitter = EventEmitter()
-
-    async def connect(self, token: str, intent_code: int):
+        self.buffer = bytearray()
+        self.inflator = zlib.decompressobj()
+        self.heartbeat_interval: int = None
+        self.ready: bool = False
+        self.rest_client = RESTClient(self._token, aiohttp.ClientSession(headers={
+            "Authorization": f"Bot {self._token}",
+            "User-Agent": "DiscordBot (https://github.com/mounderfod/discobra 0.0.1)"
+        }))
+
+    async def connect(self):
         """
         Connects to the Discord gateway and begins sending heartbeats.
         This should not be called manually.
@@ -50,23 +75,8 @@ class Client:
             hello = await gateway.recv()
             self.gateway = gateway
             threading.Thread(target=self.loop.run_forever).start()
-            heartbeat = asyncio.run_coroutine_threadsafe(
-                self.heartbeat(gateway, json.loads(hello)['d']['heartbeat_interval']), self.loop)
-            identify = {
-                "op": 2,
-                "d": {
-                    "token": token,
-                    "intents": intent_code,
-                    "properties": {
-                        "os": sys.platform,
-                        "browser": "discobra",
-                        "device": "discobra"
-                    }
-                }
-            }
-            await gateway.send(json.dumps(identify))
-            ready = await gateway.recv()
-            self.event_emitter.emit('on_ready')
+            while True:
+                await self.poll_event()
     
     async def send(self, data: dict):
         """
@@ -81,7 +91,41 @@ class Client:
         """
         Receive data from the gateway.
         """
-        pass
+        if type(msg) is bytes:
+            self.buffer.extend(msg)
+            if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff':
+                return
+
+            msg = self.inflator.decompress(self.buffer)
+            msg.decode('utf-8')
+            self.buffer = bytearray()
+        msg = json.loads(msg)
+        opcode = msg['op']
+        data = msg['d']
+        sequence = msg['s']
+
+        if opcode != GatewayEvents.DISPATCH.value:
+            if opcode == GatewayEvents.RECONNECT.value:
+                return await self.close()
+
+            if opcode == GatewayEvents.HELLO.value:
+                self.heartbeat_interval = data['heartbeat_interval']
+                asyncio.run_coroutine_threadsafe(self.heartbeat(self.heartbeat_interval), self.loop)
+                return await self.identify()
+
+            if opcode == GatewayEvents.HEARTBEAT_ACK.value:
+                return await self.heartbeat(self.heartbeat_interval)
+
+            if opcode == GatewayEvents.HEARTBEAT.value:
+                return await self.heartbeat(self.heartbeat_interval)
+
+        event = msg['t']
+
+        if event == 'READY':
+            print(data)
+
+        self.event_emitter.emit('on_' + event.lower())
+
 
     async def close(self):
         """
@@ -90,10 +134,12 @@ class Client:
         self.loop.stop()
         await self.gateway.close()
 
-    async def poll_events(self):
-        pass
+    async def poll_event(self):
+        msg = await self.gateway.recv()
+        await self.recv(msg)
+
 
-    async def heartbeat(self, gateway: websockets.WebSocketClientProtocol, interval: int):
+    async def heartbeat(self, interval: int):
         """
         Sends a heartbeat through the gateway to keep the connection active.
         This should not be called manually.
@@ -102,14 +148,27 @@ class Client:
         - gateway: The gateway to keep open.
         - interval: How often to send a heartbeat. This is given by the gateway in a Hello packet.
         """
-        while True:
-            await asyncio.sleep(interval / 1000)
-            heartbeat = {
-                "op": 1,
-                "d": None
-            }
-            await gateway.send(json.dumps(heartbeat))
-            ack = await gateway.recv()
+        await asyncio.sleep(interval / 1000)
+        heartbeat = {
+            "op": 1,
+            "d": None
+        }
+        await self.gateway.send(json.dumps(heartbeat))
+
+    async def identify(self):
+        """
+        Identify the client.
+        """
+        identify = {
+            "op": GatewayEvents.IDENTIFY,
+            "d": {
+                "token": self._token,
+                "intents": self.code,
+                "properties": {
+                    "os": sys.platform,
+                    "browser": "discobra",
+                    "device": "discobra"
+                }
 
     def event(self, coro: Optional[Callable[..., Coroutine[Any, Any, Any]]]=None, /) -> Optional[Callable[..., Coroutine[Any, Any, Any]]]:
         """
@@ -131,4 +190,4 @@ class Client:
         - token: Your bot token. Do not share this with anyone!
         """
         self._token = token
-        asyncio.run(self.connect(token, self.code))
+        asyncio.run(self.connect())
diff --git a/discord/utils/rest.py b/discord/utils/rest.py
index f7f6ab8..919d54d 100644
--- a/discord/utils/rest.py
+++ b/discord/utils/rest.py
@@ -1,15 +1,14 @@
 import aiohttp
-import asyncio
 
 from discord.utils.exceptions import APIException
 
+class RESTClient:
+    def __init__(self, token: str, session: aiohttp.ClientSession):
+        self.token = token
+        self.session = session
 
-async def get(token, url):
-    async with aiohttp.ClientSession(headers={
-        "Authorization": f"Bot {token}",
-        "User-Agent": f"DiscordBot (https://github.com/mounderfod/discobra 0.0.1)"
-    }) as session:
-        async with session.get(url='https://discord.com/api/v10' + url) as r:
+    async def get(self, url: str):
+        async with self.session.get(url='https://discord.com/api/v10' + url) as r:
             data = await r.json()
             match r.status:
                 case 200:
@@ -18,12 +17,8 @@ async def get(token, url):
                     raise APIException(data['message'])
 
 
-async def post(token, url, data):
-    async with aiohttp.ClientSession(headers={
-        "Authorization": f"Bot {token}",
-        "User-Agent": f"DiscordBot (https://github.com/mounderfod/discobra 0.0.1)"
-    }) as session:
-        async with session.post(url='https://discord.com/api/v10' + url, data=data) as r:
+    async def post(self, url: str, data):
+        async with self.session.post(url='https://discord.com/api/v10' + url, data=data) as r:
             data = await r.json()
             match r.status:
                 case 200 | 204:
@@ -32,26 +27,18 @@ async def post(token, url, data):
                     raise APIException(data['message'])
 
 
-async def patch(token, url, data):
-    async with aiohttp.ClientSession(headers={
-        "Authorization": f"Bot {token}",
-        "User-Agent": f"DiscordBot (https://github.com/mounderfod/discobra 0.0.1)"
-    }) as session:
-        async with session.patch(url='https://discord.com/api/v10' + url, data=data) as r:
-            data = await r.json()
-            match r.status:
+    async def patch(self, url, data):
+        async with self.session.patch(url='https://discord.com/api/v10' + url, data=data) as res:
+            data = await res.json()
+            match res.status:
                 case 200 | 204:
                     return data
                 case other:
                     raise APIException(data['message'])
 
 
-async def delete(token, url):
-    async with aiohttp.ClientSession(headers={
-        "Authorization": f"Bot {token}",
-        "User-Agent": f"DiscordBot (https://github.com/mounderfod/discobra 0.0.1)"
-    }) as session:
-        async with session.delete(url='https://discord.com/api/v10' + url) as r:
+    async def delete(self, url):
+        async with self.session.delete(url='https://discord.com/api/v10' + url) as r:
             data = await r.json()
             match r.status:
                 case 200: