import asyncio
import copy
import logging
import signal
import json
from datetime import datetime
from hypercorn.asyncio import serve
from hypercorn.config import Config as HyperConfig
from nacl.exceptions import BadSignatureError
from nacl.signing import VerifyKey
from quart import Quart, request, abort
from quart import Response as QuartResponse
from quart.logging import default_handler
from quart.utils import MustReloadError, restart
from typing import Any, TYPE_CHECKING
from . import utils
from .commands import Command, SubGroup
from .enums import InteractionType
from .errors import CheckFailed
from .response import BaseResponse, Ping, MessageResponse, EmptyResponse
if TYPE_CHECKING:
from .client import Client
from .context import Context
# Logger used by every handler in this module
_log: logging.Logger = logging.getLogger(__name__)

# Public API of this module
__all__ = ("DiscordHTTP",)
def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None:
""" Used by Quart to cancel all tasks on shutdown. """
tasks = [
task for task in asyncio.all_tasks(loop)
if not task.done()
]
if not tasks:
return
for task in list(tasks):
task.cancel()
if task.get_coro().__name__ == "_windows_signal_support": # type: ignore
tasks.remove(task)
loop.run_until_complete(
asyncio.gather(*tasks, return_exceptions=True)
)
for task in tasks:
if not task.cancelled() and task.exception() is not None:
loop.call_exception_handler({
"message": "unhandled exception during shutdown",
"exception": task.exception(),
"task": task
})
[docs]
class DiscordHTTP(Quart):
def __init__(self, *, client: "Client"):
"""
This serves as the fundemental HTTP server for Discord Interactions
We recommend to not touch this class, unless you know what you're doing
"""
self.uptime: datetime = utils.utcnow()
self.bot: "Client" = client
# Aliases
self.loop = self.bot.loop
self.debug_events = self.bot.debug_events
super().__init__(__name__)
# Change some of the default settings
self.config["JSONIFY_PRETTYPRINT_REGULAR"] = False
self.config["JSON_SORT_KEYS"] = False
# Remove Quart's default logging handler
_quart_log = logging.getLogger("quart.app")
_quart_log.removeHandler(default_handler)
_quart_log.setLevel(logging.CRITICAL)
async def _validate_request(self) -> None:
"""
Used to validate requests sent by Discord Webhooks
This should NOT be modified, unless you know what you're doing
"""
if not self.bot.public_key:
return abort(401, "invalid public key")
verify_key = VerifyKey(bytes.fromhex(self.bot.public_key))
signature: str = request.headers.get("X-Signature-Ed25519", "")
timestamp: str = request.headers.get("X-Signature-Timestamp", "")
try:
data = await request.data
body = data.decode("utf-8")
verify_key.verify(
f"{timestamp}{body}".encode(),
bytes.fromhex(signature)
)
except BadSignatureError:
abort(401, "invalid request signature")
except Exception:
abort(400, "invalid request body")
def _dig_subcommand(
self,
cmd: Command | SubGroup,
data: dict
) -> tuple[Command | None, list[dict]]:
"""
Used to dig through subcommands to execute correct command/autocomplete
"""
data_options: list[dict] = data["data"].get("options", [])
while isinstance(cmd, SubGroup):
find_next_step = next((
g for g in data_options
if g.get("name", None) and not g.get("value", None)
), None)
if not find_next_step:
return abort(400, "invalid command")
cmd = cmd.subcommands.get(find_next_step["name"], None) # type: ignore
if not cmd:
_log.warning(
f"Unhandled subcommand: {find_next_step['name']} "
"(not found in local command list)"
)
return abort(404, "command not found")
data_options = find_next_step.get("options", [])
return cmd, data_options
def _handle_ack_ping(
self,
ctx: "Context",
data: dict
) -> dict:
""" Used to handle ACK ping """
_ping = Ping(state=self.bot.state, data=data)
if self.bot.has_any_dispatch("ping"):
self.bot.dispatch("ping", _ping)
_log.debug(f"Discord Interactions ACK recieved ({_ping.id})")
return ctx.response.pong()
async def _handle_application_command(
self,
ctx: "Context",
data: dict
) -> QuartResponse | dict:
""" Used to handle application commands """
_log.debug("Received slash command, processing...")
command_name = data["data"]["name"]
cmd = self.bot.commands.get(command_name, None)
if not cmd:
_log.warning(
f"Unhandeled command: {command_name} "
"(not found in local command list)"
)
return QuartResponse(
"command not found",
status=404
)
cmd, _ = self._dig_subcommand(cmd, data)
# Now that the command is found, let context know about it
ctx.command = cmd
try:
# But first, check global checks
await self.bot._run_global_checks(ctx)
# Now run the command itself
payload = await cmd._make_context_and_run(
context=ctx
)
if isinstance(payload, EmptyResponse):
return QuartResponse("", status=202)
return QuartResponse(
payload.to_multipart(),
content_type=payload.content_type
)
except Exception as e:
if self.bot.has_any_dispatch("interaction_error"):
self.bot.dispatch("interaction_error", ctx, e)
else:
_log.error(
f"Error while running command {cmd.name}",
exc_info=e
)
_send_error = self.error_messages(ctx, e)
if _send_error and isinstance(_send_error, BaseResponse):
return _send_error.to_dict()
return abort(500)
async def _handle_interaction(
self,
ctx: "Context",
data: dict
) -> QuartResponse | dict:
""" Used to handle interactions """
_log.debug("Received interaction, processing...")
_custom_id = data["data"]["custom_id"]
try:
local_view = None
if (
local_view is None and
ctx.custom_id
):
local_view = self.bot._view_storage.get(
ctx.custom_id, None
)
if (
local_view is None and
ctx.message
):
local_view = self.bot._view_storage.get(
ctx.message.id, None
)
if not local_view and ctx.message.interaction:
local_view = self.bot._view_storage.get(
ctx.message.interaction.id, None
)
if local_view:
payload = await local_view.callback(ctx)
return QuartResponse(
payload.to_multipart(),
content_type=payload.content_type
)
intreact = self.bot.find_interaction(_custom_id)
if not intreact:
_log.debug(
"Unhandled interaction recieved "
f"(custom_id: {_custom_id})"
)
return QuartResponse(
"interaction not found",
status=404
)
payload = await intreact.run(ctx)
return QuartResponse(
payload.to_multipart(),
content_type=payload.content_type
)
except Exception as e:
if self.bot.has_any_dispatch("interaction_error"):
self.bot.dispatch("interaction_error", ctx, e)
else:
_log.error(
f"Error while running interaction {_custom_id}",
exc_info=e
)
return abort(500)
async def _handle_autocomplete(
self,
ctx: "Context",
data: dict
) -> QuartResponse | dict:
""" Used to handle autocomplete interactions """
_log.debug("Received autocomplete interaction, processing...")
command_name = data.get("data", {}).get("name", None)
cmd = self.bot.commands.get(command_name)
try:
if not cmd:
_log.warning(f"Unhandled autocomplete recieved (name: {command_name})")
return QuartResponse(
"command not found",
status=404
)
cmd, data_options = self._dig_subcommand(cmd, data)
find_focused = next((
x for x in data_options
if x.get("focused", False)
), None)
if not find_focused:
_log.warning(
"Failed to find focused option in autocomplete "
f"(cmd name: {command_name})"
)
return QuartResponse(
"focused option not found",
status=400
)
return await cmd.run_autocomplete(
ctx, find_focused["name"], find_focused["value"]
)
except Exception as e:
if self.bot.has_any_dispatch("interaction_error"):
self.bot.dispatch("interaction_error", ctx, e)
else:
_log.error(
f"Error while running autocomplete {cmd.name}",
exc_info=e
)
return abort(500)
async def _index_interactions_endpoint(
self
) -> QuartResponse | dict:
"""
The main function to handle all HTTP requests sent by Discord
Please do not touch this function, unless you know what you're doing
"""
await self._validate_request()
data = await request.json
if self.debug_events:
self.bot.dispatch(
"raw_interaction",
copy.deepcopy(data)
)
context = self.bot._context(self.bot, data)
data_type = data.get("type", -1)
match data_type:
case InteractionType.ping:
return self._handle_ack_ping(context, data)
case InteractionType.application_command:
return await self._handle_application_command(
context, data
)
case x if x in (
InteractionType.message_component,
InteractionType.modal_submit
):
return await self._handle_interaction(
context, data
)
case InteractionType.application_command_autocomplete:
return await self._handle_autocomplete(
context, data
)
case _: # Unknown
_log.debug(f"Unhandled interaction recieved (type: {data_type})")
return abort(400, "invalid request body")
[docs]
def error_messages(
self,
ctx: "Context",
e: Exception
) -> MessageResponse | None:
"""
Used to return error messages to Discord.
By default, it will only cover CheckFailed errors.
You can overwrite this function to return your own error messages.
Parameters
----------
ctx: `Context`
The context of the command
e: `Exception`
The exception that was raised
Returns
-------
`Optional[MessageResponse]`
The message response provided by the library error handler
"""
if isinstance(e, CheckFailed):
return ctx.response.send_message(
content=str(e),
ephemeral=True
)
[docs]
async def index_ping(self) -> tuple[dict, int] | dict:
"""
Used to ping the interaction url, to check if it's working
You can overwrite this function to return your own data as well.
Remember that it must return `dict`
"""
if not self.bot.is_ready():
return {"error": "bot is not ready yet"}, 503
return {
"@me": {
"id": self.bot.user.id,
"username": self.bot.user.name,
"discriminator": self.bot.user.discriminator,
"created_at": str(self.bot.user.created_at.isoformat()),
},
"last_reboot": {
"datetime": str(self.uptime.astimezone().isoformat()),
"timedelta": str(utils.utcnow() - self.uptime),
"unix": int(self.uptime.timestamp()),
}
}
[docs]
def jsonify(
self,
data: dict,
*,
status: int = 200,
sort_keys: bool = False,
indent: int | None = None,
) -> QuartResponse:
"""
Force Quart to respond with JSON the way you like it
Parameters
----------
data: `dict`
The data to respond with
status: `int`
The status code to respond with
sort_keys: `bool`
Whether to sort the keys or not
indent: `int | None`
If the JSON should be indented on response
Returns
-------
`QuartResponse`
The response object
"""
return QuartResponse(
json.dumps(data, sort_keys=sort_keys, indent=indent),
headers={"Content-Type": "application/json"},
status=status,
)
[docs]
def start(
self,
*,
host: str = "127.0.0.1",
port: int = 8080
) -> None:
if not self.bot.disable_default_get_path:
self.add_url_rule(
"/",
"ping",
self.index_ping,
methods=["GET"]
)
self.add_url_rule(
"/",
"index",
self._index_interactions_endpoint,
methods=["POST"]
)
try:
_log.info(f"Serving on http://{host}:{port}")
self.run(host=host, port=port, loop=self.loop)
except KeyboardInterrupt:
pass # Just don't bother showing errors...
[docs]
def run(
self,
host: str,
port: int,
loop: asyncio.AbstractEventLoop
) -> None:
""" ## Do NOT use this function, use `start` instead """
loop.set_debug(False)
shutdown_event = asyncio.Event()
def _signal_handler(*_: Any) -> None:
shutdown_event.set()
for signal_name in {"SIGINT", "SIGTERM", "SIGBREAK"}:
if hasattr(signal, signal_name):
try:
loop.add_signal_handler(
getattr(signal, signal_name),
_signal_handler
)
except NotImplementedError:
# Add signal handler may not be implemented on Windows
signal.signal(
getattr(signal, signal_name),
_signal_handler
)
server_name = self.config.get("SERVER_NAME")
sn_host = None
sn_port = None
if server_name is not None:
sn_host, _, sn_port = server_name.partition(":")
if host is None:
host = sn_host or "127.0.0.1"
if port is None:
port = int(sn_port or "8080")
task = self.run_task(
host=host,
port=port,
shutdown_trigger=shutdown_event.wait,
)
tasks = [loop.create_task(task)]
reload_ = False
try:
loop.run_until_complete(asyncio.gather(*tasks))
except MustReloadError:
reload_ = True
except KeyboardInterrupt:
pass
finally:
try:
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
asyncio.set_event_loop(None)
loop.close()
if reload_:
restart()
[docs]
def run_task(
self,
host: str = "127.0.0.1",
port: int = 8080,
shutdown_trigger=None
):
""" ## Do NOT use this function, use `start` instead """
config = HyperConfig()
config.access_log_format = "%(h)s %(r)s %(s)s %(b)s %(D)s"
config.accesslog = None
config.bind = [f"{host}:{port}"]
config.ca_certs = None
config.certfile = None
config.debug = False
config.errorlog = None
config.keyfile = None
return serve(
self,
config,
shutdown_trigger=shutdown_trigger
)