yepcord / server

Unofficial discord backend implementation in python.
GNU Affero General Public License v3.0
2 stars · 1 fork · source link

raise exception instead #135

Closed: github-actions[bot] closed this issue 8 months ago.

github-actions[bot] commented 1 year ago

https://github.com/yepcord/server/blob/dc4d1e6f656986d6032fbf775096df2048cbd10e/src/rest_api/models/channels.py#L322


from __future__ import annotations

from time import mktime
from typing import Optional, List

from dateutil.parser import parse as dparse
from pydantic import BaseModel, validator, Field

from ..utils import makeEmbedError
from ...yepcord.enums import ChannelType
from ...yepcord.errors import EmbedErr, InvalidDataErr, Errors
from ...yepcord.utils import validImage, getImage

class ChannelUpdate(BaseModel):
    """Request body for updating a channel.

    Every field is optional. Out-of-range values are clamped, truncated or
    replaced by the validators below rather than rejected, and
    :meth:`to_json` emits only the subset of fields relevant to the given
    channel type.
    """

    icon: Optional[str] = "" # Only for GROUP_DM channel
    owner_id: Optional[int] = Field(default=None, alias="owner") # Only for GROUP_DM channel
    name: Optional[str] = None # For any channel (except DM)
    # For guild channels:
    type: Optional[int] = None
    position: Optional[int] = None
    topic: Optional[str] = None
    nsfw: Optional[bool] = None
    rate_limit: Optional[int] = Field(default=None, alias="rate_limit_per_user")
    bitrate: Optional[int] = None
    user_limit: Optional[int] = None
    #permission_overwrites: List[PermissionOverwriteModel] = []
    parent_id: Optional[int] = None
    #rtc_region: Optional[str] = None
    video_quality_mode: Optional[int] = None
    default_auto_archive: Optional[int] = Field(default=None, alias="default_auto_archive_duration")
    flags: Optional[int] = None
    # Only for threads:
    auto_archive_duration: Optional[int] = None
    locked: Optional[bool] = None
    invitable: Optional[bool] = None

    class Config:
        # Accept both the alias (e.g. "rate_limit_per_user") and the field name.
        allow_population_by_field_name = True

    @validator("name")
    def validate_name(cls, value: Optional[str]):
        """Truncate the name to 100 characters."""
        if value is not None:
            value = value[:100]
        return value

    @validator("icon")
    def validate_icon(cls, value: Optional[str]):
        """Replace an icon that getImage/validImage do not accept with None."""
        if value is not None:
            if not (img := getImage(value)) or not validImage(img):
                value = None
        return value

    @validator("topic")
    def validate_topic(cls, value: Optional[str]):
        """Truncate the topic to 1024 characters."""
        if value is not None:
            value = value[:1024]
        return value

    @validator("rate_limit")
    def validate_rate_limit(cls, value: Optional[int]):
        """Clamp the slowmode interval into [0, 21600]."""
        if value is not None:
            if value < 0: value = 0
            if value > 21600: value = 21600
        return value

    @validator("bitrate")
    def validate_bitrate(cls, value: Optional[int]):
        """Raise the bitrate to at least 8000."""
        if value is not None:
            if value < 8000: value = 8000
        return value

    @validator("user_limit")
    def validate_user_limit(cls, value: Optional[int]):
        """Clamp the voice user limit into [0, 99]."""
        if value is not None:
            if value < 0: value = 0
            if value > 99: value = 99
        return value

    @validator("video_quality_mode")
    def validate_video_quality_mode(cls, value: Optional[int]):
        """Drop any video quality mode other than 0 or 1."""
        if value is not None:
            if value not in (0, 1): value = None
        return value

    # One decorator listing both fields: stacked ``@validator`` decorators do
    # not combine (the outer one overwrites the inner one's field list), and
    # this function previously reused the name ``validate_video_quality_mode``,
    # which silently discarded the validator above.
    @validator("default_auto_archive", "auto_archive_duration")
    def validate_auto_archive(cls, value: Optional[int]):
        """Snap an archive duration to the closest allowed value."""
        allowed_durations = (60, 1440, 4320, 10080)
        if value is not None and value not in allowed_durations:
            value = min(allowed_durations, key=lambda duration: abs(duration - value)) # Take closest
        return value

    def to_json(self, channel_type: int) -> dict:
        """Return the explicitly-changed fields that apply to ``channel_type``.

        Keys are field names (not aliases), and fields still at their defaults
        are omitted. Returns None for channel types not handled below.
        """
        if channel_type == ChannelType.GROUP_DM:
            return self.dict(include={"name", "icon", "owner_id"}, exclude_defaults=True)
        elif channel_type == ChannelType.GUILD_CATEGORY:
            return self.dict(include={"name", "position"}, exclude_defaults=True)
        elif channel_type == ChannelType.GUILD_TEXT:
            return self.dict(
                # TODO: add `type` when GUILD_NEWS channels will be added
                include={"name", "position", "topic", "nsfw", "rate_limit", "parent_id", "default_auto_archive"},
                exclude_defaults=True)
        elif channel_type == ChannelType.GUILD_VOICE:
            return self.dict(include={"name", "position", "nsfw", "bitrate", "user_limit", "parent_id",
                                      "video_quality_mode"}, exclude_defaults=True)

class PermissionOverwriteModel(BaseModel):
    """A single permission overwrite entry.

    :meth:`dict` serializes only ``type``, ``allow`` and ``deny``; ``id`` is
    never part of the output.
    """

    id: int
    type: int
    allow: int
    deny: int

    def dict(self, *args, **kwargs) -> dict:
        """Serialize the overwrite, ignoring any caller-supplied ``include``."""
        forced_include = {"include": {"type", "allow", "deny"}}
        return super().dict(*args, **{**kwargs, **forced_include})

class EmbedFooter(BaseModel):
    """Footer section of a message embed."""

    text: Optional[str] = None
    icon_url: Optional[str] = None

    @validator("text")
    def validate_text(cls, value: Optional[str]):
        """Reject footer text longer than 2048 characters."""
        if value is None or len(value) <= 2048:
            return value
        raise EmbedErr(makeEmbedError(27, "footer.text", {"length": "2048"}))

    @validator("icon_url")
    def validate_icon_url(cls, value: Optional[str]):
        """Reject icon URLs whose scheme (text before the first ":") is not http/https."""
        if value is None:
            return value
        scheme = value.partition(":")[0]
        if scheme in ("http", "https"):
            return value
        raise EmbedErr(makeEmbedError(24, "footer.icon_url", {"scheme": scheme}))

    def dict(self, *args, **kwargs) -> dict:
        """Serialize, always omitting fields left at their defaults."""
        return super().dict(*args, **{**kwargs, "exclude_defaults": True})

class EmbedImage(BaseModel):
    """Media entry of an embed (used for its image, thumbnail and video)."""

    url: Optional[str] = None
    width: Optional[int] = None
    height: Optional[int] = None

    @validator("url")
    def validate_url(cls, value: Optional[str]):
        """Reject URLs whose scheme (text before the first ":") is not http/https."""
        if value is None:
            return None
        scheme, _, _ = value.partition(":")
        if scheme not in ("http", "https"):
            raise EmbedErr(makeEmbedError(24, "url", {"scheme": scheme}))
        return value

    def dict(self, *args, **kwargs) -> dict:
        """Serialize, always omitting fields left at their defaults."""
        kwargs.update(exclude_defaults=True)
        return super().dict(*args, **kwargs)

class EmbedAuthor(BaseModel):
    """Author section of a message embed."""

    name: Optional[str] = None
    url: Optional[str] = None
    # A URL string, checked by the same scheme validator as ``url``
    # (an ``int`` annotation would reject every real URL).
    icon_url: Optional[str] = None

    @validator("name")
    def validate_name(cls, value: Optional[str]):
        """Reject author names longer than 256 characters."""
        if value is not None:
            if len(value) > 256:
                raise EmbedErr(makeEmbedError(27, "author.name", {"length": "256"}))
        return value

    # One decorator listing both fields: stacked ``@validator`` decorators do
    # not combine, so previously only ``url`` was actually checked.
    @validator("url", "icon_url")
    def validate_url(cls, value: Optional[str]):
        """Reject URLs whose scheme (text before the first ":") is not http/https."""
        if value is not None:
            if (scheme := value.split(":")[0]) not in ["http", "https"]:
                raise EmbedErr(makeEmbedError(24, "url", {"scheme": scheme}))
        return value

    def dict(self, *args, **kwargs) -> dict:
        """Serialize, always omitting fields left at their defaults."""
        kwargs["exclude_defaults"] = True
        return super().dict(*args, **kwargs)

class EmbedField(BaseModel):
    """One name/value field of a message embed; both parts are required."""

    name: Optional[str] = None
    value: Optional[str] = None

    # ``always=True`` runs the check even when the key is omitted, so a
    # missing name is rejected just like an empty one.
    @validator("name", always=True)
    def validate_name(cls, value: Optional[str]):
        """Require a non-empty name of at most 256 characters.

        Raises:
            EmbedErr: if the name is missing/empty or too long.
        """
        if not value:
            raise EmbedErr(makeEmbedError(23, "fields.name"))
        if len(value) > 256:
            raise EmbedErr(makeEmbedError(27, "fields.name", {"length": "256"}))
        # Pydantic stores the validator's return value; without this return
        # the field was always stored as None.
        return value

    @validator("value", always=True)
    def validate_value(cls, value: Optional[str]):
        """Require a non-empty value of at most 1024 characters.

        Raises:
            EmbedErr: if the value is missing/empty or too long.
        """
        if not value:
            raise EmbedErr(makeEmbedError(23, "fields.value"))
        if len(value) > 1024:
            raise EmbedErr(makeEmbedError(27, "fields.value", {"length": "1024"}))
        return value

class EmbedModel(BaseModel):
    """A message embed as submitted by the client.

    Length and format violations raise ``EmbedErr``. Sub-objects missing their
    key part (footer text, media url, author name) are dropped rather than
    rejected, and more than 25 fields are truncated.
    """

    title: str
    type: Optional[str] = None
    description: Optional[str] = None
    url: Optional[str] = None
    timestamp: Optional[str] = None
    color: Optional[int] = None
    footer: Optional[EmbedFooter] = None
    image: Optional[EmbedImage] = None
    thumbnail: Optional[EmbedImage] = None
    video: Optional[EmbedImage] = None
    author: Optional[EmbedAuthor] = None
    fields: List[EmbedField] = Field(default_factory=list)

    @validator("title")
    def validate_title(cls, value: str):
        """Require a non-empty title of at most 256 characters."""
        if not value: raise EmbedErr(makeEmbedError(23))
        if len(value) > 256:
            raise EmbedErr(makeEmbedError(27, "title", {"length": "256"}))
        return value

    @validator("type")
    def validate_type(cls, value: Optional[str]):
        """Coerce any supplied type to "rich" (not run when ``type`` is omitted)."""
        return "rich"

    @validator("description")
    def validate_description(cls, value: Optional[str]):
        """Reject descriptions longer than 4096 characters."""
        if value is not None:
            if len(value) > 4096:
                raise EmbedErr(makeEmbedError(27, "description", {"length": "4096"}))
        return value

    @validator("url")
    def validate_url(cls, value: Optional[str]):
        """Reject URLs whose scheme (text before the first ":") is not http/https."""
        if value is not None:
            if (scheme := value.split(":")[0]) not in ["http", "https"]:
                raise EmbedErr(makeEmbedError(24, "url", {"scheme": scheme}))
        return value

    @validator("timestamp")
    def validate_timestamp(cls, value: Optional[str]):
        """Reject timestamps that cannot be parsed and converted to epoch time."""
        if value is not None:
            try:
                mktime(dparse(value).timetuple())
            # dateutil and mktime raise OverflowError for out-of-range dates;
            # report those as a bad timestamp too instead of a server error.
            except (ValueError, OverflowError):
                raise EmbedErr(makeEmbedError(25, "timestamp", {"value": value}))
        return value

    @validator("color")
    def validate_color(cls, value: Optional[int]):
        """Reject colors outside 0..0xFFFFFF."""
        if value is not None:
            if value > 0xffffff or value < 0:
                raise EmbedErr(makeEmbedError(26, "color"))
        return value

    @validator("footer")
    def validate_footer(cls, value: Optional[EmbedFooter]):
        """Drop a footer that has no text."""
        if value is not None:
            if not value.text:
                value = None
        return value

    # One decorator listing all three fields: stacked ``@validator``
    # decorators do not combine, so previously only ``image`` was checked.
    @validator("image", "thumbnail", "video")
    def validate_image(cls, value: Optional[EmbedImage]):
        """Drop a media entry that has no url."""
        if value is not None:
            if not value.url:
                value = None
        return value

    @validator("author")
    def validate_author(cls, value: Optional[EmbedAuthor]):
        """Drop an author that has no name."""
        if value is not None:
            if not value.name:
                value = None
        return value

    @validator("fields")
    def validate_fields(cls, value: List[EmbedField]):
        """Keep at most the first 25 fields."""
        if len(value) > 25:
            value = value[:25]
        # Pydantic stores the validator's return value; without this return
        # ``fields`` was always stored as None.
        return value

    def dict(self, *args, **kwargs) -> dict:
        """Serialize, always omitting fields left at their defaults."""
        kwargs["exclude_defaults"] = True
        return super().dict(*args, **kwargs)

class MessageCreate(BaseModel):
    """Request body for sending a message."""

    content: Optional[str] = None
    nonce: Optional[str] = None
    embeds: List[EmbedModel] = Field(default_factory=list)
    # Only the referenced message's id is kept (flattened in ``__init__``).
    message_reference: Optional[int] = None
    flags: Optional[int] = None

    def __init__(self, **data):
        # The reference arrives as an object like {"message_id": ...}; keep only
        # the id. A null or id-less reference becomes None instead of crashing
        # with TypeError/KeyError; other values are left to pydantic validation.
        reference = data.get("message_reference")
        if isinstance(reference, dict):
            data["message_reference"] = reference.get("message_id")
        super().__init__(**data)

    @validator("content")
    def validate_content(cls, value: Optional[str]):
        """Reject content longer than 2000 characters.

        Raises:
            InvalidDataErr: 400 / 50035 form error keyed on ``content``.
        """
        if value is not None and len(value) > 2000:
            raise InvalidDataErr(400,
                                 Errors.make(50035,
                                             {"content": {"code": "BASE_TYPE_MAX_LENGTH",
                                                          "message": "Must be 2000 or fewer in length."}}))
        return value

    def to_json(self) -> dict:
        """Return the fields that differ from their defaults."""
        return self.dict(exclude_defaults=True)

class MessageUpdate(BaseModel):
    """Request body for editing a message."""

    content: Optional[str] = None
    embeds: List[EmbedModel] = Field(default_factory=list)

    @validator("content")
    def validate_content(cls, value: Optional[str]):
        """Reject content longer than 2000 characters.

        Raises:
            InvalidDataErr: 400 / 50035 form error keyed on ``content``.
        """
        if value is not None and len(value) > 2000:
            raise InvalidDataErr(400,
                                 Errors.make(50035,
                                             {"content": {"code": "BASE_TYPE_MAX_LENGTH",
                                                          "message": "Must be 2000 or fewer in length."}}))
        return value

    def to_json(self) -> dict:
        """Return the fields that differ from their defaults."""
        return self.dict(exclude_defaults=True)

class InviteCreate(BaseModel):
    """Request body for creating an invite.

    ``max_age`` defaults to 86400 and ``max_uses`` defaults to 0.
    """

    max_age: Optional[int] = Field(default=86400)
    max_uses: Optional[int] = Field(default=0)

class WebhookCreate(BaseModel):
    """Request body for creating a webhook; ``name`` is required."""

    name: Optional[str] = None

    # ``always=True`` runs the check even when ``name`` is omitted entirely;
    # without it only an explicit null/empty name was rejected.
    @validator("name", always=True)
    def validate_name(cls, value: Optional[str]):
        """Require a non-empty webhook name.

        Raises:
            InvalidDataErr: 400 / 50035 "Required field" error keyed on ``name``.
        """
        if not value:
            raise InvalidDataErr(400,
                                 Errors.make(50035,
                                             {"name": {"code": "BASE_TYPE_REQUIRED", "message": "Required field"}}))
        return value