Source code for discord.api.rest

# -*- coding: utf-8 -*-
# cython: language_level=3
# Copyright (c) 2021-present VincentRPS

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE
from __future__ import annotations

import asyncio
import json
import logging
import typing
import weakref
from typing import Any, Iterable, Optional, Sequence, TypeVar, Union
from urllib.parse import quote

import aiohttp

from discord import utils
from discord.file import File
from discord.state import ConnectionState
from discord.types.dict import Dict

from ..internal.exceptions import Forbidden, NotFound, RESTError, ServerError

# Module-level logger, named after this module.
_log = logging.getLogger(__name__)

# Names exported by ``from <this module> import *``.
__all__: typing.List[str] = [
    'RESTClient',
]

# Type variable bound to ``PadLock``; used to annotate ``PadLock.__enter__``.
PAD = typing.TypeVar('PAD', bound='PadLock')

# Assigns the string 'websocket' to the ``WEBSOCKET`` attribute of ``aiohttp.hdrs``.
# NOTE(review): the reason for this override is not visible in this file —
# presumably a header-casing workaround; confirm before removing.
aiohttp.hdrs.WEBSOCKET = 'websocket'


async def parse_tj(
    response: aiohttp.ClientResponse,
) -> typing.Union[Dict, str]:
    """Read a response body and decode it as JSON when it is declared as JSON.

    Parameters
    ----------
    response
        The response to read. Its body is decoded as UTF-8 text.

    Returns
    -------
    The decoded JSON value when the ``Content-Type`` media type is
    ``application/json``; otherwise the raw response text.

    Raises
    ------
    json.JSONDecodeError
        If the response declares JSON but the body is not valid JSON.
    """
    text = await response.text(encoding='utf-8')

    # The header may be missing entirely (e.g. an error page served by
    # Cloudflare), and when present it may carry parameters such as
    # ``; charset=utf-8`` or use different casing. Compare only the media type.
    content_type = response.headers.get('content-type') or ''
    media_type = content_type.split(';', 1)[0].strip().lower()

    if media_type == 'application/json':
        return json.loads(text)  # type: ignore

    return text


class Route:
    """Describes one API call: an HTTP method, an endpoint path and the
    identifiers used to build its ratelimit bucket key.

    Recognised keyword parameters are ``guild_id``, ``channel_id``,
    ``webhook_id`` and ``webhook_token``; each defaults to ``None`` when it
    is not supplied. Any other keyword is ignored.
    """

    def __init__(self, method: str, endpoint: str, **params: typing.Any):
        lookup = params.get

        self.method = method
        self.endpoint = endpoint

        # Guild / channel identifiers.
        self.guild_id: typing.Optional[int] = lookup('guild_id')
        self.channel_id: typing.Optional[int] = lookup('channel_id')

        # Webhook identifiers.
        self.webhook_id: typing.Optional[int] = lookup('webhook_id')
        self.webhook_token: typing.Optional[str] = lookup('webhook_token')

    @property
    def bucket(self) -> str:
        """Colon-separated key built from the method, endpoint and identifiers."""
        components = (
            self.method,
            self.endpoint,
            self.guild_id,
            self.channel_id,
            self.webhook_id,
            self.webhook_token,
        )
        return ':'.join(f'{component}' for component in components)


class PadLock:
    """Context manager that releases an already-acquired lock on exit,
    unless the release has been deferred.

    Based off the PadLock of interactions.py & MaybeUnlock of discord.py.
    Entering the context does not acquire the lock; the caller is expected
    to have done so beforehand.
    """

    def __init__(self, lock: asyncio.Lock):
        # ``MaybeUnlock`` stays True until ``defer`` is called.
        self.MaybeUnlock: bool = True
        self.lock: asyncio.Lock = lock

    def defer(self) -> None:
        """Skip the release on exit; something else must release the lock."""
        self.MaybeUnlock = False

    def __enter__(self: PAD) -> PAD:
        return self

    def __exit__(self, exc_type, exc, traceback) -> None:
        if not self.MaybeUnlock:
            # Release was deferred; leave the lock held.
            return
        self.lock.release()


class RESTClient:
    """Represents a Rest connection with Discord.

    .. versionadded:: 0.3.0

    Attributes
    ----------
    url
        The Discord API URL
    user_agent
        The ``User-Agent`` string sent with every request.
    header
        The base headers sent to discord. Per-request headers
        (``Content-Type``, ``X-Audit-Log-Reason``) are added to a copy and
        never stored here; ``Authorization`` is stored here once a ``token``
        has been passed to :meth:`send`.
    state
        The connection state; its ``loop`` is used to schedule deferred
        ratelimit lock releases.
    proxy
        Optional proxy passed through to every request.
    proxy_auth
        Optional proxy authentication passed through to every request.
    """

    def __init__(self, *, state=None, proxy=None, proxy_auth=None, version=10):
        self.user_agent = 'DiscordBot (https://github.com/VincentRPS/discord.io)'
        self.header: typing.Dict[str, str] = {'User-Agent': self.user_agent}
        # One lock per ratelimit bucket; weak values let unused locks be collected.
        self._locks: weakref.WeakValueDictionary[
            str, asyncio.Lock
        ] = weakref.WeakValueDictionary()
        # Set while no global ratelimit is active; cleared while one is.
        self._has_global: asyncio.Event = asyncio.Event()
        self._has_global.set()
        self.state = state or ConnectionState()
        self.proxy = proxy
        self.proxy_auth = proxy_auth
        # Created lazily by ``_ensure_session`` and reused until ``close``.
        self._session: typing.Optional[aiohttp.ClientSession] = None
        self.url = f'https://discord.com/api/v{version}'
        if version not in (8, 9, 10):
            raise DeprecationWarning(
                'The API Version you are running has been decommissioned, please bump the version.'
            )

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the shared session, creating a new one if it is missing or closed."""
        session = self._session
        if session is None or session is utils.MISSING or session.closed:
            session = self._session = aiohttp.ClientSession()
        return session

    async def send(  # noqa: ignore
        self,
        route: Route,
        files: typing.Optional[typing.Sequence[File]] = None,
        form: typing.Optional[typing.Iterable[Dict]] = None,
        **params: typing.Any,
    ):
        """Sends a request to discord

        .. versionadded:: 0.3.0

        Parameters
        ----------
        route
            The method, endpoint and ratelimit bucket of the request.
        files
            Files attached to the request; each is reset before every attempt.
        form
            Keyword-argument mappings, one per multipart form field.
        **params
            Extra keyword arguments forwarded to the underlying request.
            ``json`` is serialised into ``data``, ``token`` is turned into a
            persistent ``Authorization`` header and ``reason`` into an
            ``X-Audit-Log-Reason`` header for this request only.

        Returns
        -------
        The parsed response body on a 2xx status. ``None`` if every attempt
        ended with a status this method does not handle explicitly.

        Raises
        ------
        Forbidden, NotFound, ServerError
            On 403, 404 and 500 responses respectively. These propagate
            unchanged only if they derive from ``RESTError``; otherwise they
            are wrapped in ``RESTError`` like any other error.
        RESTError
            On a 429 without a JSON body or ``via`` header, and for any other
            error raised while performing the request.
        """
        method = route.method
        url = self.url + route.endpoint
        bucket = route.bucket
        # Reuse one session for all requests instead of creating (and
        # leaking) a new one per call; it is closed by ``close``.
        session = self._ensure_session()

        lock = self._locks.get(bucket)
        if lock is None:
            lock = asyncio.Lock()
            if bucket is not None:
                self._locks[bucket] = lock

        if self.proxy is not None:
            params['proxy'] = self.proxy
        if self.proxy_auth is not None:
            params['proxy_auth'] = self.proxy_auth

        if 'token' in params:
            # Stored on the client so later requests stay authenticated.
            self.header['Authorization'] = 'Bot ' + params.pop('token')

        # Per-request headers go on a copy so they cannot leak into other
        # requests (e.g. a JSON content type on a later multipart upload).
        headers = dict(self.header)

        if 'json' in params:
            headers['Content-Type'] = 'application/json'  # Only json.
            params['data'] = json.dumps(params.pop('json'))

        reason = params.pop('reason', None)
        if reason is not None:
            headers['X-Audit-Log-Reason'] = quote(str(reason), safe='/ ')

        params['headers'] = headers

        # Materialise once so every retry sees the same fields, even when a
        # one-shot iterable was passed.
        form_fields = list(form) if form else []

        if not self._has_global.is_set():
            await self._has_global.wait()

        await lock.acquire()
        with PadLock(lock) as padl:
            for tries in range(5):
                if files:
                    for f in files:
                        f.reset(seek=tries)

                if form_fields:
                    form_data = aiohttp.FormData()
                    for kwargs in form_fields:
                        form_data.add_field(**kwargs)
                    params['data'] = form_data

                try:
                    async with session.request(method, url, **params) as r:
                        d = await parse_tj(r)
                        _log.debug(
                            '< %s %s %s %s', method, url, params.get('data'), bucket
                        )

                        # Some endpoints don't send these ratelimit headers.
                        remains = r.headers.get('X-RateLimit-Remaining')
                        reset_after = r.headers.get('X-RateLimit-Reset-After')

                        if (
                            remains == '0'
                            and r.status != 429
                            and reset_after is not None
                            and padl.MaybeUnlock
                        ):
                            # The bucket was depleted: keep the lock held and
                            # release it once the bucket resets. Checking
                            # ``MaybeUnlock`` schedules the release only once.
                            delay = float(reset_after)
                            padl.defer()
                            _log.debug(
                                'A ratelimit Bucket was depleted. (bucket: %s, retry: %s)',
                                bucket,
                                delay,
                            )
                            self.state.loop.call_later(delay, lock.release)

                        if r.status == 429:
                            if not r.headers.get('via') or isinstance(d, str):
                                # handles cloudflare bans
                                raise RESTError(d)

                            retry_in: float = d['retry_after']
                            _log.warning(
                                'The Rest Client seems to be ratelimited,'
                                ' Retrying in: %.2f seconds. Handled with the bucket: %s',
                                retry_in,
                                bucket,
                            )

                            is_global = d.get('global', False)
                            if is_global:
                                _log.debug(
                                    'Global ratelimit was hit, retrying in %s', retry_in
                                )
                                self._has_global.clear()

                            try:
                                await asyncio.sleep(retry_in)
                            finally:
                                # Always lift the global block again, even if
                                # the sleep is cancelled; otherwise every
                                # later request would wait forever.
                                if is_global:
                                    self._has_global.set()
                                    _log.debug('Global ratelimit has been depleted.')

                            _log.debug(
                                'Finished waiting for the ratelimit, now retrying...'
                            )
                            continue

                        elif r.status == 403:
                            raise Forbidden(d)
                        elif r.status == 404:
                            raise NotFound(d)
                        elif r.status == 500:
                            raise ServerError(d)
                        elif 300 > r.status >= 200:
                            _log.debug(
                                '> %s (bucket: %s)',
                                d,
                                r.headers.get('X-RateLimit-Bucket', None),
                            )
                            return d
                        else:
                            # Unhandled status: log it and try again.
                            _log.error(r)

                except RESTError:
                    # Already one of our own errors; keep its concrete type.
                    raise
                except Exception as exc:
                    raise RESTError(f'{exc}') from exc

        # Every attempt ended with an unhandled status.
        return None

    async def cdn(self, url) -> bytes:
        """Download the raw bytes served at ``url``.

        Raises
        ------
        NotFound
            If the response status is 404.
        Forbidden
            If the response status is 403.
        RESTError
            For any other non-200 status.
        """
        session = self._ensure_session()
        async with session.get(url) as r:
            if r.status == 200:
                return await r.read()
            elif r.status == 404:
                raise NotFound(f'{url} is not found!')
            elif r.status == 403:
                raise Forbidden('You are not allowed to see this image!')
            else:
                raise RESTError(r, 'asset recovery failed.')

    async def close(self) -> None:
        """Close the shared session if one has been created."""
        session = self._session
        if session is not None and session is not utils.MISSING:
            await session.close()  # Closes the session

    async def create_if_not_exists(self):
        """Create the shared session unless an open one already exists."""
        self._ensure_session()