diff options
author | Noah <mounderfod@gmail.com> | 2022-07-09 18:59:36 +0100 |
---|---|---|
committer | Noah <mounderfod@gmail.com> | 2022-07-09 18:59:36 +0100 |
commit | 57fff05c0e2058855946ffa4e5c02872c534ee65 (patch) | |
tree | 41faac71321b81f0ac34083f20e57ce461ba46f6 /discord | |
parent | 5058a47b28fa1f0a5982a65c2a65955ea158a0e4 (diff) | |
parent | 55fe04a1d73625600c87e41fe45e5a713bda0aba (diff) | |
download | discobra-57fff05c0e2058855946ffa4e5c02872c534ee65.tar.gz |
Merge branch 'master' of https://github.com/mounderfod/discobra
Conflicts: discord/client.py
Diffstat (limited to 'discord')
-rw-r--r-- | discord/client.py | 131 | ||||
-rw-r--r-- | discord/utils/rest.py | 41 |
2 files changed, 109 insertions, 63 deletions
diff --git a/discord/client.py b/discord/client.py index 1100c83..b9abe1e 100644 --- a/discord/client.py +++ b/discord/client.py @@ -1,16 +1,33 @@ import asyncio +from enum import IntEnum import json import sys import threading import warnings from typing import Optional, Coroutine, Any, Callable +import zlib +import aiohttp import websockets from .utils import EventEmitter -from .utils.rest import get +from .utils.rest import RESTClient from .intents import Intents, get_number from .user import User +class GatewayEvents(IntEnum): + DISPATCH = 0 + HEARTBEAT = 1 + IDENTIFY = 2 + PRESENCE = 3 + VOICE_STATE = 4 + VOICE_PING = 5 + RESUME = 6 + RECONNECT = 7 + REQUEST_MEMBERS = 8 + INVALIDATE_SESSION = 9 + HELLO = 10 + HEARTBEAT_ACK = 11 + GUILD_SYNC = 12 class Client: """ @@ -22,7 +39,7 @@ class Client: @property async def user(self): """The `discord.user.User` associated with the client.""" - data = await get(self._token, '/users/@me') + data = await self.rest_client.get(self._token, '/users/@me') return User(data) def __init__(self, intents: list[Intents]): @@ -33,11 +50,19 @@ class Client: "Discord developer portal.") if Intents.GUILD_MEMBERS in intents or Intents.GUILD_PRESENCES in intents: warnings.warn("You are using one or more privileged intent (Guild Members and/or Guild Presences). You " - "must enable them in the Discord developer portal.") - self.code = get_number(intents) + "must enable them in the Discord developer portal.") + self.code: int = get_number(intents) self.event_emitter = EventEmitter() - - async def connect(self, token: str, intent_code: int): + self.buffer = bytearray() + self.inflator = zlib.decompressobj() + self.heartbeat_interval: int = None + self.ready: bool = False + self.rest_client = RESTClient(self._token, aiohttp.ClientSession(headers={ + "Authorization": f"Bot {self._token}", + "User-Agent": "DiscordBot (https://github.com/mounderfod/discobra 0.0.1)" + })) + + async def connect(self): """ Connects to the Discord gateway and begins sending heartbeats. This should not be called manually. @@ -50,23 +75,8 @@ class Client: hello = await gateway.recv() self.gateway = gateway threading.Thread(target=self.loop.run_forever).start() - heartbeat = asyncio.run_coroutine_threadsafe( - self.heartbeat(gateway, json.loads(hello)['d']['heartbeat_interval']), self.loop) - identify = { - "op": 2, - "d": { - "token": token, - "intents": intent_code, - "properties": { - "os": sys.platform, - "browser": "discobra", - "device": "discobra" - } - } - } - await gateway.send(json.dumps(identify)) - ready = await gateway.recv() - self.event_emitter.emit('on_ready') + while True: + await self.poll_event() async def send(self, data: dict): """ @@ -81,7 +91,41 @@ class Client: """ Receive data from the gateway. """ - pass + if type(msg) is bytes: + self.buffer.extend(msg) + if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff': + return + + msg = self.inflator.decompress(self.buffer) + msg.decode('utf-8') + self.buffer = bytearray() + msg = json.loads(msg) + opcode = msg['op'] + data = msg['d'] + sequence = msg['s'] + + if opcode != GatewayEvents.DISPATCH.value: + if opcode == GatewayEvents.RECONNECT.value: + return await self.close() + + if opcode == GatewayEvents.HELLO.value: + self.heartbeat_interval = data['heartbeat_interval'] + asyncio.run_coroutine_threadsafe(self.heartbeat(self.heartbeat_interval), self.loop) + return await self.identify() + + if opcode == GatewayEvents.HEARTBEAT_ACK.value: + return await self.heartbeat(self.heartbeat_interval) + + if opcode == GatewayEvents.HEARTBEAT.value: + return await self.heartbeat(self.heartbeat_interval) + + event = msg['t'] + + if event == 'READY': + print(data) + + self.event_emitter.emit('on_' + event.lower()) + async def close(self): """ @@ -90,10 +134,12 @@ class Client: self.loop.stop() await self.gateway.close() - async def poll_events(self): - pass + async def poll_event(self): + msg = await self.gateway.recv() + await self.recv(msg) + - async def heartbeat(self, gateway: websockets.WebSocketClientProtocol, interval: int): + async def heartbeat(self, interval: int): """ Sends a heartbeat through the gateway to keep the connection active. This should not be called manually. @@ -102,14 +148,27 @@ class Client: - gateway: The gateway to keep open. - interval: How often to send a heartbeat. This is given by the gateway in a Hello packet. """ - while True: - await asyncio.sleep(interval / 1000) - heartbeat = { - "op": 1, - "d": None - } - await gateway.send(json.dumps(heartbeat)) - ack = await gateway.recv() + await asyncio.sleep(interval / 1000) + heartbeat = { + "op": 1, + "d": None + } + await self.gateway.send(json.dumps(heartbeat)) + + async def identify(self): + """ + Identify the client. + """ + identify = { + "op": GatewayEvents.IDENTIFY, + "d": { + "token": self._token, + "intents": self.code, + "properties": { + "os": sys.platform, + "browser": "discobra", + "device": "discobra" + } def event(self, coro: Optional[Callable[..., Coroutine[Any, Any, Any]]]=None, /) -> Optional[Callable[..., Coroutine[Any, Any, Any]]]: """ @@ -131,4 +190,4 @@ class Client: - token: Your bot token. Do not share this with anyone! """ self._token = token - asyncio.run(self.connect(token, self.code)) + asyncio.run(self.connect()) diff --git a/discord/utils/rest.py b/discord/utils/rest.py index f7f6ab8..919d54d 100644 --- a/discord/utils/rest.py +++ b/discord/utils/rest.py @@ -1,15 +1,14 @@ import aiohttp -import asyncio from discord.utils.exceptions import APIException +class RESTClient: + def __init__(self, token: str, session: aiohttp.ClientSession): + self.token = token + self.session = session -async def get(token, url): - async with aiohttp.ClientSession(headers={ - "Authorization": f"Bot {token}", - "User-Agent": f"DiscordBot (https://github.com/mounderfod/discobra 0.0.1)" - }) as session: - async with session.get(url='https://discord.com/api/v10' + url) as r: + async def get(self, url: str): + async with self.session.get(url='https://discord.com/api/v10' + url) as r: data = await r.json() match r.status: case 200: @@ -18,12 +17,8 @@ async def get(token, url): raise APIException(data['message']) -async def post(token, url, data): - async with aiohttp.ClientSession(headers={ - "Authorization": f"Bot {token}", - "User-Agent": f"DiscordBot (https://github.com/mounderfod/discobra 0.0.1)" - }) as session: - async with session.post(url='https://discord.com/api/v10' + url, data=data) as r: + async def post(self, url: str, data): + async with self.session.post(url='https://discord.com/api/v10' + url, data=data) as r: data = await r.json() match r.status: case 200 | 204: @@ -32,26 +27,18 @@ async def post(token, url, data): raise APIException(data['message']) -async def patch(token, url, data): - async with aiohttp.ClientSession(headers={ - "Authorization": f"Bot {token}", - "User-Agent": f"DiscordBot (https://github.com/mounderfod/discobra 0.0.1)" - }) as session: - async with session.patch(url='https://discord.com/api/v10' + url, data=data) as r: - data = await r.json() - match r.status: + async def patch(self, url, data): + async with self.session.patch(url='https://discord.com/api/v10' + url, data=data) as res: + data = await res.json() + match res.status: case 200 | 204: return data case other: raise APIException(data['message']) -async def delete(token, url): - async with aiohttp.ClientSession(headers={ - "Authorization": f"Bot {token}", - "User-Agent": f"DiscordBot (https://github.com/mounderfod/discobra 0.0.1)" - }) as session: - async with session.delete(url='https://discord.com/api/v10' + url) as r: + async def delete(self, url): + async with self.session.delete(url='https://discord.com/api/v10' + url) as r: data = await r.json() match r.status: case 200: |