deepghs / waifuc

Efficient Train Data Collector for Anime Waifu
https://deepghs.github.io/waifuc/
MIT License

dev(narugo): add parallel download for data sources #16

Closed: narugo1992 closed this 10 months ago

narugo1992 commented 11 months ago

The core code is in waifuc.source.web.ParallelWebDataSource.

It is based on SerializableParallelModule, NonSerializableParallelModule, and Stopped in waifuc.utils (the unit tests for them have passed).
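
Judging by the serializable flag in the example below, the main behavioural difference between the two modules is whether downloaded items come back in submission order. A generic sketch of that idea with the standard library (purely illustrative, not the actual module code; the fetch function and worker count are placeholders):

import concurrent.futures

def fetch(item):
    # stand-in for a single download task
    return item * 2

items = list(range(10))

with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
    # "serializable" behaviour: results come back in submission order
    ordered = list(pool.map(fetch, items))

    # "non-serializable" behaviour: results come back as each task finishes
    futures = [pool.submit(fetch, x) for x in items]
    unordered = [f.result() for f in concurrent.futures.as_completed(futures)]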

ParallelWebDataSource has the same API as WebDataSource; you can simply change the base class of the data source classes to ParallelWebDataSource and add a few extra __init__ arguments.

Here is an example:

import os.path
import re
from typing import Optional, Iterator, List, Tuple, Union

from ditk import logging
from hbutils.system import urlsplit
from requests.auth import HTTPBasicAuth

from waifuc.config.meta import __TITLE__, __VERSION__
from waifuc.export import SaveExporter
from waifuc.source.web import NoURL, ParallelWebDataSource
from waifuc.utils import get_requests_session, srequest

logging.try_init_root(logging.INFO)

class DanbooruLikeSource(ParallelWebDataSource):
    # just the same as the original DanbooruSource; only the base class and the extra parallel __init__ arguments differ
    def __init__(self, tags: List[str], min_size: Optional[int] = 800, download_silent: bool = True,
                 username: Optional[str] = None, api_key: Optional[str] = None,
                 site_name: Optional[str] = 'danbooru', site_url: Optional[str] = 'https://danbooru.donmai.us/',
                 group_name: Optional[str] = None, max_workers: Optional[int] = None, serializable: bool = True):
        ParallelWebDataSource.__init__(self, group_name or site_name, None, download_silent, max_workers, serializable)
        self.session = get_requests_session(headers={
            "User-Agent": f"{__TITLE__}/{__VERSION__}",
            'Content-Type': 'application/json; charset=utf-8',
        })
        self.auth = HTTPBasicAuth(username, api_key) if username and api_key else None
        self.site_name, self.site_url = site_name, site_url
        self.tags = tags
        self.min_size = min_size

    def _get_data_from_raw(self, raw):
        return raw

    def _select_url(self, data):
        if self.min_size is not None and "media_asset" in data and "variants" in data["media_asset"]:
            variants = data["media_asset"]["variants"]
            width, height, url = None, None, None
            for item in variants:
                if 'width' in item and 'height' in item and \
                        item['width'] >= self.min_size and item['height'] >= self.min_size:
                    if url is None or item['width'] < width:
                        width, height, url = item['width'], item['height'], item['url']

            if url is not None:
                return url

        if 'file_url' not in data:
            raise NoURL

        return data['file_url']

    def _get_tags(self, data):
        return re.split(r'\s+', data["tag_string"])

    def _iter_data(self) -> Iterator[Tuple[Union[str, int], str, dict]]:
        page = 1
        while True:
            resp = srequest(self.session, 'GET', f'{self.site_url}/posts.json', params={
                "format": "json",
                "limit": "100",
                "page": str(page),
                "tags": ' '.join(self.tags),
            }, auth=self.auth)
            resp.raise_for_status()
            page_items = self._get_data_from_raw(resp.json())
            if not page_items:
                break

            for data in page_items:
                try:
                    url = self._select_url(data)
                except NoURL:
                    continue

                _, ext_name = os.path.splitext(urlsplit(url).filename)
                filename = f'{self.group_name}_{data["id"]}{ext_name}'
                meta = {
                    self.site_name: data,
                    'group_id': f'{self.group_name}_{data["id"]}',
                    'filename': filename,
                    'tags': {key: 1.0 for key in self._get_tags(data)}
                }
                yield data['id'], url, meta

            page += 1

class DanbooruSource(DanbooruLikeSource):
    def __init__(self, tags: List[str],
                 min_size: Optional[int] = 800, download_silent: bool = True,
                 username: Optional[str] = None, api_key: Optional[str] = None,
                 group_name: Optional[str] = None, max_workers: Optional[int] = None, serializable: bool = True):
        DanbooruLikeSource.__init__(self, tags, min_size, download_silent, username, api_key,
                                    'danbooru', 'https://danbooru.donmai.us/', group_name, max_workers, serializable)

if __name__ == '__main__':
    s = DanbooruSource(
        ['surtr_(arknights)'],
        max_workers=4,
        serializable=False,  # when True, items are yielded in submission order
        download_silent=False,  # when True, the download progress bar is hidden
    )
    s[:100].export(SaveExporter('test_zerochan_imgs', no_meta=True))  # download the first 100 images
    s.cleanup()  # TODO: we don't want this line if possible
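
For comparison, since the API matches WebDataSource, the same first-100-image download with the existing serial DanbooruSource keeps an identical call pattern (a minimal sketch assuming the current waifuc.source.DanbooruSource API; the output directory name is only illustrative):

from waifuc.export import SaveExporter
from waifuc.source import DanbooruSource

s = DanbooruSource(['surtr_(arknights)'])
s[:100].export(SaveExporter('test_serial_imgs', no_meta=True))
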
narugo1992 commented 10 months ago

Closed as not planned.