diff options
author | mjk134 <57556877+mjk134@users.noreply.github.com> | 2022-07-08 20:19:55 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-07-08 20:19:55 +0000 |
commit | 99da423d99378eaf446031451f9f6de76010d7c0 (patch) | |
tree | 835386993c02736cf5ea345722a88699e8227686 /discord | |
parent | 51304896fc682787670506b44054150054dfd71e (diff) | |
download | discobra-99da423d99378eaf446031451f9f6de76010d7c0.tar.gz |
- Add `run` to `Client` to start async loop
- Add event decorator with a `on_ready` event call
Diffstat (limited to 'discord')
-rw-r--r-- | discord/client.py | 29 |
1 files changed, 19 insertions, 10 deletions
diff --git a/discord/client.py b/discord/client.py index 73fe0e3..c500d1c 100644 --- a/discord/client.py +++ b/discord/client.py @@ -2,27 +2,23 @@ import asyncio import json import sys import threading - import websockets - +from typing import Coroutine from discord.intents import Intents, gen_number -loop = asyncio.get_event_loop() - - class Client: - def __init__(self, token: str, intents: list[Intents]): + def __init__(self, intents: list[Intents]): self.gateway = None - code = gen_number(intents) - asyncio.run(self.connect(token, code)) + self.loop = asyncio.get_event_loop() + self.code = gen_number(intents) async def connect(self, token: str, intent_code: int): async with websockets.connect("wss://gateway.discord.gg/?v=10&encoding=json") as gateway: hello = await gateway.recv() self.gateway = gateway - threading.Thread(target=loop.run_forever).start() + threading.Thread(target=self.loop.run_forever).start() heartbeat = asyncio.run_coroutine_threadsafe( - self.heartbeat(gateway, json.loads(hello)['d']['heartbeat_interval']), loop) + self.heartbeat(gateway, json.loads(hello)['d']['heartbeat_interval']), self.loop) identify = { "op": 2, "d": { @@ -37,6 +33,8 @@ class Client: } await gateway.send(json.dumps(identify)) ready = await gateway.recv() + if (hasattr(self, 'on_ready')): + await getattr(self, 'on_ready')() async def heartbeat(self, gateway: websockets.WebSocketClientProtocol, interval: int): while True: @@ -47,3 +45,14 @@ class Client: } await gateway.send(json.dumps(heartbeat)) ack = await gateway.recv() + + def event(self, coro: Coroutine, /) -> Coroutine: + if not asyncio.iscoroutinefunction(coro): + raise TypeError('event registered must be a coroutine function') + + setattr(self, coro.__name__, coro) + return coro + + def run(self, token: str): + self.token = token + asyncio.run(self.connect(self.token, self.code)) |