about summary refs log tree commit diff stats
diff options
context:
space:
mode:
authormjk134 <57556877+mjk134@users.noreply.github.com>2022-07-12 08:40:43 +0000
committerGitHub <noreply@github.com>2022-07-12 08:40:43 +0000
commit3c61b5814f2f7636e269624c72e876d48a8f4eef (patch)
tree7f564a8bcdb9256cc2aef3e70f18319407db36ad
parentde4b806a7e5b823680acae4627607198cf1f59a6 (diff)
downloaddiscobra-3c61b5814f2f7636e269624c72e876d48a8f4eef.tar.gz
feat(client): Switch to aiohttp for webscokets connections
-rw-r--r--discord/client.py24
-rw-r--r--discord/utils/rest.py17
2 files changed, 28 insertions, 13 deletions
diff --git a/discord/client.py b/discord/client.py
index 5f5c3b0..e028a1a 100644
--- a/discord/client.py
+++ b/discord/client.py
@@ -50,6 +50,7 @@ class Client:
     You need to initialise one of these and then use `run()` with a token to login.
     """
     _token: str
+    rest_client: RESTClient
 
     @property
     async def user(self):
@@ -72,10 +73,6 @@ class Client:
         self.inflator = zlib.decompressobj()
         self.heartbeat_interval: int = None
         self.ready: bool = False
-        self.rest_client = RESTClient(self._token, aiohttp.ClientSession(headers={
-            "Authorization": f"Bot {self._token}",
-            "User-Agent": "DiscordBot (https://github.com/mounderfod/discobra 0.0.1)"
-        }))
 
     async def connect(self):
         """
@@ -86,7 +83,12 @@ class Client:
         - token: Your bot token.
         - intent_code: The number which represents the `discord.intents.Intents` being used.
         """
-        async with websockets.connect("wss://gateway.discord.gg/?v=10&encoding=json") as gateway:
+        timeout = aiohttp.ClientTimeout(total=60)
+        self.rest_client = RESTClient(self._token, aiohttp.ClientSession(headers={
+            "Authorization": f"Bot {self._token}",
+            "User-Agent": "DiscordBot (https://github.com/mounderfod/discobra 0.0.1)"
+        }, timeout=timeout))
+        async with self.rest_client.session.ws_connect("wss://gateway.discord.gg/?v=10&encoding=json") as gateway:
             self.gateway = gateway
             threading.Thread(target=self.loop.run_forever).start()
             while True:
@@ -99,7 +101,7 @@ class Client:
         **Parameters:**
         - data: The data to send to the gateway.
         """
-        await self.gateway.send(json.dumps(data))
+        await self.gateway.send_str(json.dumps(data))
 
     async def recv(self, msg):
         """
@@ -148,8 +150,13 @@ class Client:
         await self.gateway.close()
 
     async def poll_event(self):
-        msg = await self.gateway.recv()
-        await self.recv(msg)
+        async for msg in self.gateway:
+            if msg.type == aiohttp.WSMsgType.TEXT:
+                await self.recv(msg.data)
+            elif msg.type == aiohttp.WSMsgType.CLOSED:
+                break
+            elif msg.type == aiohttp.WSMsgType.ERROR:
+                break
 
     async def heartbeat(self, interval: int):
         """
@@ -206,4 +213,5 @@ class Client:
         - token: Your bot token. Do not share this with anyone!
         """
         self._token = token
+
         asyncio.run(self.connect())
diff --git a/discord/utils/rest.py b/discord/utils/rest.py
index 652de5d..fc3f729 100644
--- a/discord/utils/rest.py
+++ b/discord/utils/rest.py
@@ -11,7 +11,14 @@ class RESTClient:
     """
     def __init__(self, token: str, session: aiohttp.ClientSession):
         self.token = token
-        self.session = session
+        self._session = session
+
+    @property
+    def session(self) -> aiohttp.ClientSession:
+        """
+        Returns the _session used by the client.
+        """
+        return self._session
 
     async def get(self, url: str):
         """
@@ -20,7 +27,7 @@ class RESTClient:
         **Parameters:**
         - url: The part of the request URL that goes after `https://discord.com/api/v10`
         """
-        async with self.session.get(url='https://discord.com/api/v10' + url) as r:
+        async with self._session.get(url='https://discord.com/api/v10' + url) as r:
             data = await r.json()
             match r.status:
                 case 200:
@@ -36,7 +43,7 @@ class RESTClient:
         - url: The part of the request URL that goes after `https://discord.com/api/v10`
         - data: The data to post.
         """
-        async with self.session.post(url='https://discord.com/api/v10' + url, data=data) as r:
+        async with self._session.post(url='https://discord.com/api/v10' + url, data=data) as r:
             data = await r.json()
             match r.status:
                 case 200 | 204 | 201:
@@ -52,7 +59,7 @@ class RESTClient:
         - url: The part of the request URL that goes after `https://discord.com/api/v10`
         - data: The data to patch.
         """
-        async with self.session.patch(url='https://discord.com/api/v10' + url, data=data) as res:
+        async with self._session.patch(url='https://discord.com/api/v10' + url, data=data) as res:
             data = await res.json()
             match res.status:
                 case 200 | 204:
@@ -67,7 +74,7 @@ class RESTClient:
         **Parameters:**
         - url: The part of the request URL that goes after `https://discord.com/api/v10`
         """
-        async with self.session.delete(url='https://discord.com/api/v10' + url) as r:
+        async with self._session.delete(url='https://discord.com/api/v10' + url) as r:
             data = await r.json()
             match r.status:
                 case 200: