diff options
author | mjk134 <57556877+mjk134@users.noreply.github.com> | 2022-07-12 08:40:43 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-07-12 08:40:43 +0000 |
commit | 3c61b5814f2f7636e269624c72e876d48a8f4eef (patch) | |
tree | 7f564a8bcdb9256cc2aef3e70f18319407db36ad | |
parent | de4b806a7e5b823680acae4627607198cf1f59a6 (diff) | |
download | discobra-3c61b5814f2f7636e269624c72e876d48a8f4eef.tar.gz |
feat(client): Switch to aiohttp for webscokets connections
-rw-r--r-- | discord/client.py | 24 | ||||
-rw-r--r-- | discord/utils/rest.py | 17 |
2 files changed, 28 insertions, 13 deletions
diff --git a/discord/client.py b/discord/client.py index 5f5c3b0..e028a1a 100644 --- a/discord/client.py +++ b/discord/client.py @@ -50,6 +50,7 @@ class Client: You need to initialise one of these and then use `run()` with a token to login. """ _token: str + rest_client: RESTClient @property async def user(self): @@ -72,10 +73,6 @@ class Client: self.inflator = zlib.decompressobj() self.heartbeat_interval: int = None self.ready: bool = False - self.rest_client = RESTClient(self._token, aiohttp.ClientSession(headers={ - "Authorization": f"Bot {self._token}", - "User-Agent": "DiscordBot (https://github.com/mounderfod/discobra 0.0.1)" - })) async def connect(self): """ @@ -86,7 +83,12 @@ class Client: - token: Your bot token. - intent_code: The number which represents the `discord.intents.Intents` being used. """ - async with websockets.connect("wss://gateway.discord.gg/?v=10&encoding=json") as gateway: + timeout = aiohttp.ClientTimeout(total=60) + self.rest_client = RESTClient(self._token, aiohttp.ClientSession(headers={ + "Authorization": f"Bot {self._token}", + "User-Agent": "DiscordBot (https://github.com/mounderfod/discobra 0.0.1)" + }, timeout=timeout)) + async with self.rest_client.session.ws_connect("wss://gateway.discord.gg/?v=10&encoding=json") as gateway: self.gateway = gateway threading.Thread(target=self.loop.run_forever).start() while True: @@ -99,7 +101,7 @@ class Client: **Parameters:** - data: The data to send to the gateway. """ - await self.gateway.send(json.dumps(data)) + await self.gateway.send_str(json.dumps(data)) async def recv(self, msg): """ @@ -148,8 +150,13 @@ class Client: await self.gateway.close() async def poll_event(self): - msg = await self.gateway.recv() - await self.recv(msg) + async for msg in self.gateway: + if msg.type == aiohttp.WSMsgType.TEXT: + await self.recv(msg.data) + elif msg.type == aiohttp.WSMsgType.CLOSED: + break + elif msg.type == aiohttp.WSMsgType.ERROR: + break async def heartbeat(self, interval: int): """ @@ -206,4 +213,5 @@ class Client: - token: Your bot token. Do not share this with anyone! """ self._token = token + asyncio.run(self.connect()) diff --git a/discord/utils/rest.py b/discord/utils/rest.py index 652de5d..fc3f729 100644 --- a/discord/utils/rest.py +++ b/discord/utils/rest.py @@ -11,7 +11,14 @@ class RESTClient: """ def __init__(self, token: str, session: aiohttp.ClientSession): self.token = token - self.session = session + self._session = session + + @property + def session(self) -> aiohttp.ClientSession: + """ + Returns the _session used by the client. + """ + return self._session async def get(self, url: str): """ @@ -20,7 +27,7 @@ class RESTClient: **Parameters:** - url: The part of the request URL that goes after `https://discord.com/api/v10` """ - async with self.session.get(url='https://discord.com/api/v10' + url) as r: + async with self._session.get(url='https://discord.com/api/v10' + url) as r: data = await r.json() match r.status: case 200: @@ -36,7 +43,7 @@ class RESTClient: - url: The part of the request URL that goes after `https://discord.com/api/v10` - data: The data to post. """ - async with self.session.post(url='https://discord.com/api/v10' + url, data=data) as r: + async with self._session.post(url='https://discord.com/api/v10' + url, data=data) as r: data = await r.json() match r.status: case 200 | 204 | 201: @@ -52,7 +59,7 @@ class RESTClient: - url: The part of the request URL that goes after `https://discord.com/api/v10` - data: The data to patch. """ - async with self.session.patch(url='https://discord.com/api/v10' + url, data=data) as res: + async with self._session.patch(url='https://discord.com/api/v10' + url, data=data) as res: data = await res.json() match res.status: case 200 | 204: @@ -67,7 +74,7 @@ class RESTClient: **Parameters:** - url: The part of the request URL that goes after `https://discord.com/api/v10` """ - async with self.session.delete(url='https://discord.com/api/v10' + url) as r: + async with self._session.delete(url='https://discord.com/api/v10' + url) as r: data = await r.json() match r.status: case 200: |