# Source code for discord.api.gateway

# -*- coding: utf-8 -*-
# cython: language_level=3
# Copyright (c) 2021-present VincentRPS

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE
"""Implementation of the Discord Gateway."""

from __future__ import annotations

import asyncio
import json
import logging
import platform
import zlib
from random import random
from time import perf_counter, time
from typing import Any, Coroutine, List

import aiohttp

from discord import utils
from discord.events import catalog
from discord.snowflake import Snowflakeish
from discord.types.dict import Dict

from ..http import RESTFactory
from ..internal.dispatcher import Dispatcher
from ..state import ConnectionState

# Module logger shared by the shard and gateway classes.
_log = logging.getLogger(__name__)

# Every complete zlib-stream message from the Gateway ends with a
# Z_SYNC_FLUSH marker; a frame ending in these bytes closes a message.
ZLIB_SUFFIX = b'\x00\x00\xff\xff'

# Query string appended to the Gateway URL: API version, payload encoding,
# and transport compression mode.
url_extension = '?v=10&encoding=json&compress=zlib-stream'


class Shard:
    """Represents a single Discord Gateway shard.

    Parameters
    ----------
    state
        The :class:`ConnectionState` cache to use.
    dispatcher
        The :class:`Dispatcher` to use.
    shard_id: :class:`int`
        The shard number.
    url: :class:`str`
        The Gateway URL to connect to.
    mobile: :class:`bool`
        Whether to identify with a mobile presence.

    Attributes
    ----------
    buffer
        Bytes received since the last complete zlib-stream message.
    remaining
        How many sends are still allowed in the current rate-limit window.
    per
        Length of a rate-limit window, in seconds.
    window
        Start of the current rate-limit window (a :func:`time.time` value).
    max
        Number of sends allowed per window.
    inflator
        The zlib decompressor for the current connection.
    latency
        Seconds between the last heartbeat and its acknowledgement
        (``nan`` until the first acknowledgement arrives).
    """

    # Close codes after which reconnecting cannot help; closed() raises.
    _FATAL_CLOSE_CODES = frozenset({4004, 4010, 4011, 4012, 4013, 4014})
    # Close codes after which the session cannot be resumed, so the next
    # connection must IDENTIFY from scratch instead of RESUMEing.
    _NEW_SESSION_CLOSE_CODES = frozenset({4007, 4009})

    def __init__(
        self,
        state: ConnectionState,
        dispatcher: Dispatcher,
        shard_id: int,
        url: str,
        mobile: bool = False,
    ):
        self.remaining = 110
        self.per = 60.0
        self.window = 0.0
        self.max = 110
        self.state = state
        self.url = url
        self.mobile = mobile
        self.dis = dispatcher
        self.inflator = zlib.decompressobj()
        self.shard_id = shard_id
        self.buffer = bytearray()
        self.last_recv = perf_counter()
        self.last_send = perf_counter()
        self.last_ack = perf_counter()
        self.token: str | None = None
        # Last sequence number seen; sent with heartbeats and resumes.
        self._seq: int | None = None
        self._session_id: str | None = None
        self._session: aiohttp.ClientSession | None = None
        self._ratelimit_lock: asyncio.Lock = asyncio.Lock()
        # perf_counter() value when the last heartbeat was sent (for latency).
        self._last_heartbeat = perf_counter()
        self._heartbeat_task: asyncio.Task | None = None
        self._watchdog_task: asyncio.Task | None = None
        # Set by close() so that readers/watchdog do not reconnect.
        self._closing = False
        self.ws: aiohttp.ClientWebSocketResponse | None = None
        self.latency: float = float('nan')
        self.ready: asyncio.Event = asyncio.Event()
        self.state.loop.create_task(self.enter())

    async def enter(self) -> None:
        """Create the aiohttp session used for the WebSocket, unless one is open."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()

    async def connect(self, token: str) -> None:
        """Connects to the url specified, with the token

        Identifies when there is no session yet, otherwise resumes it.

        Parameters
        ----------
        token
            Your bot token.

        Attributes
        ----------
        _session
            The aiohttp ClientSession
        ws
            The WebSocket connection.
        """
        # The session is normally created by the task scheduled in __init__,
        # but do not depend on that task having run already.
        await self.enter()
        self.token = token
        self._closing = False
        # zlib-stream keeps one compression context per connection, so a new
        # socket needs a fresh inflator and an empty buffer.
        self.inflator = zlib.decompressobj()
        self.buffer = bytearray()
        self.ws = await self._session.ws_connect(self.url)
        self.last_recv = perf_counter()
        # Every socket needs its own reader; the watchdog runs only once.
        self.state.loop.create_task(self.recv())
        if self._watchdog_task is None or self._watchdog_task.done():
            self._watchdog_task = self.state.loop.create_task(self.check_connection())
        if self._session_id is None:
            await self.identify()
        else:
            await self.resume()
            _log.debug('Reconnected to the Gateway')

    @property
    def is_ratelimited(self) -> bool:
        """``True`` if the current window's send budget is exhausted."""
        now = time()
        if now > self.window + self.per:
            return False
        return self.remaining == 0

    def delay(self) -> float:
        """Consume one send from the budget; return how long to wait first (0.0 if none)."""
        now = time()
        if now > self.window + self.per:
            # The previous window has expired: refill the budget.
            self.remaining = self.max
        if self.remaining == self.max:
            self.window = now
        if self.remaining == 0:
            return self.per - (now - self.window)
        self.remaining -= 1
        if self.remaining == 0:
            self.window = now
        return 0.0

    async def block(self) -> None:
        """Wait, if needed, until the rate limiter allows another send."""
        async with self._ratelimit_lock:
            delay = self.delay()
            if delay:
                _log.warning(
                    'Shard %s was ratelimited, waiting %.2f seconds.',
                    self.shard_id,
                    delay,
                )
                await asyncio.sleep(delay)

    async def check_connection(self) -> None:
        """Watchdog: every 20 s, reconnect if nothing was received for 60 s."""
        while not self._closing:
            await asyncio.sleep(20)
            if self._closing:
                return
            if self.last_recv + 60.0 < perf_counter():
                _log.warning(
                    f'Shard {self.shard_id} has stopped receiving from the gateway, reconnecting'
                )
                await self.ws.close(code=4000)
                await self.closed(4000)
            elif self.latency > 10:
                _log.warning(f'Shard {self.shard_id} is behind by {self.latency}')

    async def send(self, data: Dict) -> None:
        """Send a request to the Gateway via the shard

        The send counts against this shard's rate-limit budget and waits
        when that budget is exhausted.

        Parameters
        ----------
        data
            :class:`Dict`
            The data to send
        """
        await self.block()
        await self._send_raw(data)

    async def _send_raw(self, data: Dict) -> None:
        """Serialize ``data`` as JSON and send it without touching the rate limiter."""
        self.last_send = perf_counter()
        _log.debug('< %s', data)
        await self.ws.send_bytes(json.dumps(data).encode('utf-8'))

    def _decompress(self, chunk: bytes) -> str | None:
        """Buffer ``chunk``; return the decoded payload once a full message arrived.

        Returns ``None`` while the buffered bytes do not yet end with
        ``ZLIB_SUFFIX``. Raises :class:`zlib.error` or
        :class:`UnicodeDecodeError` on corrupted data; the buffer is
        cleared either way once a full message was attempted.
        """
        self.buffer.extend(chunk)
        if not self.buffer.endswith(ZLIB_SUFFIX):
            return None
        try:
            return self.inflator.decompress(self.buffer).decode('utf-8')
        finally:
            self.buffer = bytearray()

    async def recv(self) -> None:
        """Read and handle messages from the current socket until it closes.

        When the server closes the socket, :meth:`closed` is called with its
        close code. Nothing happens if :meth:`close` was called, or if another
        path already replaced the socket.
        """
        ws = self.ws
        async for msg in ws:
            if msg.type == aiohttp.WSMsgType.ERROR:
                _log.warning('Shard %s got a WebSocket error: %s', self.shard_id, msg.data)
                break
            if msg.type != aiohttp.WSMsgType.BINARY:
                _log.debug('Shard %s ignoring %s message', self.shard_id, msg.type)
                continue

            try:
                raw = self._decompress(msg.data)
            except (zlib.error, UnicodeDecodeError):
                # The stream context is unusable now; start a new connection.
                _log.warning('Shard %s received corrupted data, reconnecting', self.shard_id)
                await ws.close(code=4000)
                await self.closed(4000)
                return
            if raw is None:
                continue  # part of a message that has not fully arrived yet

            data = json.loads(raw)
            _log.debug('> %s', data)
            # 's' is null for non-dispatch payloads; keep the last real value.
            if data.get('s') is not None:
                self._seq = data['s']
            self.dis.dispatch('RAW_SOCKET_RECEIVE')
            self.last_recv = perf_counter()

            op = data['op']
            if op == 0:
                if data['t'] == 'READY':
                    # only READY carries the session_id we need for resuming.
                    await self._ready(data)
                    self.dis.dispatch('READY')
                else:
                    catalog.Cataloger(data, self.dis, self.state)
            elif op == 1:
                # The Gateway asked for an immediate heartbeat.
                self._last_heartbeat = perf_counter()
                await self._send_raw({'op': 1, 'd': self._seq})
            elif op in (7, 9):
                # 7: reconnect requested; 9: invalid session ('d' = resumable).
                await ws.close(code=4000)
                if op == 9 and not data.get('d'):
                    self._session_id = None
                    self._seq = None
                    # Wait 1-5 s before re-identifying.
                    await asyncio.sleep(1 + random() * 4)
                await self.closed(4000)
                return
            elif op == 10:
                await self.hello(data)
            elif op == 11:
                self.last_ack = perf_counter()
                self.latency = self.last_ack - self._last_heartbeat
            else:
                self.dis.dispatch(op, data)

        if self._closing or ws is not self.ws:
            return
        code = ws.close_code
        if code is not None:
            await self.closed(code)

    async def closed(self, code: int) -> None:
        """React to the socket closing with ``code``: reconnect, or raise if fatal.

        Raises
        ------
        RuntimeError
            For close codes where reconnecting cannot succeed (e.g. 4004).
        """
        _log.error('Shard %s closed with code %s', self.shard_id, code)
        if code in self._FATAL_CLOSE_CODES:
            raise RuntimeError(
                f'Shard {self.shard_id} closed with unrecoverable code {code}'
            )
        if code in self._NEW_SESSION_CLOSE_CODES:
            # The session cannot be resumed: identify from scratch.
            self._session_id = None
            self._seq = None
        await self.connect(token=self.token)

    async def heartbeat(self, interval: float) -> None:
        """Send a heartbeat every ``interval`` seconds while this socket is current and open.

        Heartbeats bypass the rate limiter. Its budget of 110 sends is kept
        below the Gateway's per-minute limit.
        """
        ws = self.ws
        while ws is not None and not ws.closed and ws is self.ws:
            self._last_heartbeat = perf_counter()
            await self._send_raw({'op': 1, 'd': self._seq})
            await asyncio.sleep(interval)

    async def _start_heartbeat(self, interval: float) -> None:
        """Sleep a random jitter of up to ``interval``, then run :meth:`heartbeat`."""
        await asyncio.sleep(interval * random())
        await self.heartbeat(interval)

    async def close(self, code: int = 4000) -> None:
        """Close the shard for good: stop heartbeats/watchdog and close the socket and session."""
        self._closing = True
        for task in (self._heartbeat_task, self._watchdog_task):
            if task is not None and not task.done():
                task.cancel()
        if self.ws is not None:
            await self.ws.close(code=code)
        self.buffer.clear()
        if self._session is not None and not self._session.closed:
            await self._session.close()

    async def hello(self, data: Dict) -> None:
        """Handle HELLO (op 10) by (re)starting the heartbeat in the background."""
        interval = data['d']['heartbeat_interval'] / 1000
        if self._heartbeat_task is not None and not self._heartbeat_task.done():
            self._heartbeat_task.cancel()
        # Run in a task so the jitter sleep does not stall the receive loop.
        self._heartbeat_task = self.state.loop.create_task(self._start_heartbeat(interval))

    async def _ready(self, data: Dict) -> None:
        """Store the session id from READY and mark the shard and state ready."""
        self._session_id = data['d']['session_id']
        self.state._ready.set()
        self.ready.set()

    async def identify(self) -> None:
        """Send IDENTIFY (op 2) to start a new session."""
        # NOTE(review): Gateway v10 documents these property keys without the
        # '$' prefix, so confirm they are still accepted.
        await self.send(
            {
                'op': 2,
                'd': {
                    'token': self.token,
                    'intents': self.state._bot_intents,
                    'properties': {
                        '$os': platform.system(),
                        '$browser': 'discord.io'
                        if self.mobile is False
                        else 'Discord iOS',
                        '$device': 'discord.io',
                    },
                    'shard': (self.shard_id, self.state.shard_count),
                    'v': 9,
                    'compress': True,
                },
            }
        )

    async def resume(self) -> None:
        """Send RESUME (op 6) to continue the stored session from the last sequence."""
        await self.send(
            {
                'op': 6,
                'd': {
                    'token': self.token,
                    'session_id': self._session_id,
                    'seq': self._seq,
                },
            }
        )
class Gateway:
    """Represents a Gateway connection with Discord.

    Parameters
    ----------
    state
        The :class:`ConnectionState` to use.
    dispatcher
        The :class:`Dispatcher` to use.
    factory
        The :class:`RESTFactory` to use.
    mobile
        If your bot should have a mobile presence or not.
    """

    def __init__(
        self,
        state: ConnectionState,
        dispatcher: Dispatcher,
        factory: RESTFactory,
        mobile: bool = False,
    ):
        self._s = state
        self.mobile = mobile
        self.count = self._s.shard_count
        self._d = dispatcher
        self._f = factory
        self.shards: List[Shard] = []

    @utils.copy_doc(Shard.connect)
    async def connect(self, token: str) -> None:
        r = await self._f.get_gateway_bot()
        if self.count is None:
            # No explicit shard count: use the one Discord recommends.
            shds = r['shards']
            self._s.shard_count = shds
        else:
            shds = self.count
        for shard in range(shds):
            s = Shard(
                self._s,
                self._d,
                shard,
                mobile=self.mobile,
                url=r['url'] + url_extension,
            )
            self._s.loop.create_task(s.connect(token))
            # Bring shards up one at a time: each must be READY first.
            await s.ready.wait()
            self.shards.append(s)
            _log.info('Shard %s has connected to Discord', shard)

    @utils.copy_doc(Shard.send)
    async def send(self, payload: Dict, shard: int | None = None) -> None:
        if shard is None:
            # No shard given: broadcast to every connected shard.
            for s in self.shards:
                await s.send(payload)
        else:
            await self.shards[shard].send(payload)

    async def _chunk_members(self) -> None:
        """After 20 s, request the members (op 8) of every cached guild from its shard."""
        await asyncio.sleep(20)
        for guild in self._s.guilds._cache.values():
            # Discord's shard formula: (guild_id >> 22) % shard_count.
            shard_id = (int(guild['id']) >> 22) % self._s.shard_count
            await self.shards[shard_id].send(
                {'op': 8, 'd': {'guild_id': guild['id'], 'query': '', 'limit': 0}}
            )

    async def voice_state(
        self, guild: int, channel: Snowflakeish, mute: bool, deaf: bool
    ) -> None:
        """Send a voice state update (op 4) for ``guild`` through its shard."""
        payload = {
            'op': 4,
            'd': {
                'guild_id': guild,
                'channel_id': channel,
                'self_mute': mute,
                'self_deaf': deaf,
            },
        }
        shard_id = (int(guild) >> 22) % self._s.shard_count
        await self.shards[shard_id].send(payload)

    @property
    def latency(self) -> float:
        """Average heartbeat latency of all shards in seconds (0.0 with no shards).

        A shard that has not yet received a heartbeat ACK reports ``nan``,
        which makes the average ``nan`` too.
        """
        if not self.shards:
            return 0.0
        return sum(shard.latency for shard in self.shards) / len(self.shards)