Source code for discord_http.gateway.client


import logging
import asyncio

from datetime import datetime, UTC
from typing import Coroutine, TYPE_CHECKING

from .shard import Shard
from .object import PlayingStatus

if TYPE_CHECKING:
    from ..object import Snowflake
    from ..client import GatewayCacheFlags, Client, Intents

_log = logging.getLogger("discord_http")

__all__ = (
    "GatewayClient",
)


[docs] class GatewayClient: def __init__( self, bot: "Client", *, cache_flags: "GatewayCacheFlags | None" = None, intents: "Intents | None" = None, automatic_shards: bool = True, shard_id: int | None = None, shard_count: int = 1, shard_ids: list[int] | None = None, max_concurrency: int | None = None ): self.bot = bot self.intents = intents self.cache_flags = cache_flags self.automatic_shards = automatic_shards self.shard_id = shard_id self.shard_count = shard_count self.shard_ids = shard_ids self.max_concurrency = max_concurrency self.__shards: dict[int, Shard] = {} self.bot.backend.add_url_rule( "/shards", "shards", self._index_websocket_status, # type: ignore methods=["GET"] )
[docs] def get_shard(self, shard_id: int) -> Shard | None: """ Returns the shard object of the shard with the specified ID. Parameters ---------- shard_id: `int` The ID of the shard to get. Returns ------- `Optional[Shard]` The shard object with the specified ID, or `None` if not found. """ return self.__shards.get(shard_id, None)
[docs] async def change_presence(self, status: PlayingStatus) -> None: """ Changes the presence of all shards to the specified status. Parameters ---------- status: `PlayingStatus` The status to change to. """ for shard in self.__shards.values(): await shard.change_presence(status)
async def _index_websocket_status(self) -> dict[int, dict]: _now = datetime.now(UTC) return { shard_id: { "ping": shard.status.ping, "latency": shard.status.latency, "activity": { "last": str(shard._last_activity), "between": str(_now - shard._last_activity) } } for shard_id, shard in sorted( self.__shards.items(), key=lambda x: x[0] ) } async def _fetch_gateway(self) -> tuple[int, int]: r = await self.bot.state.query("GET", "/gateway/bot") return ( r.response["shards"], r.response["session_start_limit"]["max_concurrency"] ) async def _launch_shard(self, shard_id: int) -> None: """ Individual shard launching Parameters ---------- shard_id: `int` The shard ID to launch """ try: shard = Shard( bot=self.bot, intents=self.intents, cache_flags=self.cache_flags, shard_id=shard_id, shard_count=self.shard_count, api_version=self.bot.api_version, debug_events=self.bot.debug_events ) shard.connect() while not shard.status.session_id: await asyncio.sleep(0.5) except Exception as e: _log.error("Error launching shard, trying again...", exc_info=e) return await self._launch_shard(shard_id) self.__shards[shard_id] = shard
[docs] def shard_by_guild_id(self, guild_id: "Snowflake | int") -> int: """ Returns the shard ID of the shard that the guild is in Parameters ---------- guild_id: `Snowflake | int` The ID of the guild to get the shard ID of Returns ------- `int` The shard ID of the guild """ return (int(guild_id) >> 22) % self.shard_count
async def _launch_all_shards(self) -> None: """ Launches all the shards """ if self.automatic_shards: self.shard_count, self.max_concurrency = await self._fetch_gateway() if self.shard_count == 1: # There is no need to shard if there is only 1 shard _log.debug("Sharding disabled, no point in sharding 1 shard") self.max_concurrency = None shard_ids = self.shard_ids or range(self.shard_count) if not self.max_concurrency: for shard_id in shard_ids: await self._launch_shard(shard_id) _log.debug(f"All {len(shard_ids)} shard(s) have launched") else: chunks = [ list(shard_ids[i:i + self.max_concurrency]) for i in range(0, len(shard_ids), self.max_concurrency) ] for i, shard_chunk in enumerate(chunks, start=1): _booting: list[Coroutine] = [ self._launch_shard(shard_id) for shard_id in shard_chunk ] _log.debug(f"Launching bucket {i}/{len(chunks)}") await asyncio.gather(*_booting) if i != len(chunks): _log.debug(f"Bucket {i}/{len(chunks)} shards launched, waiting (5s/bucket)") await asyncio.sleep(5) else: _log.debug(f"Bucket {i}/{len(chunks)} shards launched, last bucket, skipping wait") _log.debug(f"All {len(chunks)} bucket(s) have launched a total of {self.shard_count} shard(s)") asyncio.create_task(self._delay_full_ready()) async def _delay_full_ready(self) -> None: _waiting: list[Coroutine] = [ g.wait_until_ready() for g in self.__shards.values() ] # Gather all shards to now wait until they are ready await asyncio.gather(*_waiting) self.bot._shards_ready.set() _log.info("discord.http/gateway is now ready")
[docs] def start(self) -> None: """ Start the gateway client """ self.bot.loop.create_task(self._launch_all_shards())
[docs] async def close(self) -> None: """ Close the gateway client """ async def _close(): to_close = [ asyncio.ensure_future(shard.close(kill=True)) for shard in self.__shards.values() ] if to_close: await asyncio.wait(to_close) _task = asyncio.create_task(_close()) await _task