deepghs / imgutils

A convenient and user-friendly anime-style image data processing library that integrates various advanced anime-style image processing models
https://dghs-imgutils.deepghs.org/
MIT License

dev(narugo): embedding inverse #94

Closed: narugo1992 closed this 6 months ago

narugo1992 commented 6 months ago
```python
import numpy as np
from huggingface_hub import hf_hub_download

from imgutils.tagging.wd14 import inv_wd14_by_predictions

model_name = 'ConvNext'
scale: int = 2000

# Download pre-computed sample predictions and their ground-truth embeddings
samples = np.load(hf_hub_download(
    repo_id='deepghs/wd14_tagger_inversion',
    repo_type='dataset',
    filename=f'{model_name}/samples_{scale}.npz',
))

predictions, embeddings = samples['preds'], samples['embs']

# Recover (normalized) embeddings from the tagger's prediction vectors
inv_embs = inv_wd14_by_predictions(predictions, model_name=model_name, norm=True)
print('Inverted embeddings:', inv_embs)
print(inv_embs.shape)

# L2-normalize the ground-truth embeddings so both sides are unit vectors
expected_embs = embeddings / np.linalg.norm(embeddings, axis=-1, keepdims=True)
print('Expected embeddings:', expected_embs)
print(expected_embs.shape)

# For unit vectors, the row-wise dot product is the cosine similarity
sims = (inv_embs * expected_embs).sum(axis=-1)
print('Similarities:', sims)
sim = sims.mean()
print('Mean Similarity:', sim)  # 0.997
```
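The snippet does not show how `inv_wd14_by_predictions` works internally. One plausible approach, assuming the tagger's output head is a single dense layer followed by a sigmoid, is to apply the inverse sigmoid (logit) to the predictions and solve the resulting linear system for the embedding by least squares. The sketch below uses only NumPy on a synthetic head; `invert_linear_sigmoid_head`, `weight`, and `bias` are hypothetical names standing in for the real model weights, not part of imgutils.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def invert_linear_sigmoid_head(preds, weight, bias, eps=1e-6, norm=True):
    """Hypothetical helper: recover e from preds = sigmoid(e @ weight + bias).

    Assumes a single dense layer + sigmoid head.
    preds: (n, num_tags), weight: (emb_dim, num_tags), bias: (num_tags,)
    """
    p = np.clip(preds, eps, 1.0 - eps)   # avoid infinite logits at exactly 0 or 1
    logits = np.log(p) - np.log1p(-p)    # inverse sigmoid
    target = logits - bias               # target = e @ weight
    # Solve weight.T @ e.T = target.T for e.T in the least-squares sense
    embs_t, *_ = np.linalg.lstsq(weight.T, target.T, rcond=None)
    embs = embs_t.T                      # (n, emb_dim)
    if norm:
        embs = embs / np.linalg.norm(embs, axis=-1, keepdims=True)
    return embs


# Synthetic stand-in for a tagger head (more tags than embedding dims)
rng = np.random.default_rng(0)
emb_dim, num_tags, n = 16, 200, 5
weight = rng.normal(size=(emb_dim, num_tags)) * 0.3
bias = rng.normal(size=num_tags) * 0.1
true_embs = rng.normal(size=(n, emb_dim))
preds = sigmoid(true_embs @ weight + bias)

inv = invert_linear_sigmoid_head(preds, weight, bias, norm=True)
expected = true_embs / np.linalg.norm(true_embs, axis=-1, keepdims=True)
sims = (inv * expected).sum(axis=-1)  # cosine similarity per sample
print('Mean Similarity:', sims.mean())
```

Under these assumptions the recovery is exact only when there are at least as many tags as embedding dimensions and the predictions are not saturated; clipping near 0 and 1 discards information, which may be one reason a real inversion lands slightly below 1.0.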
codecov[bot] commented 6 months ago

Codecov Report

Attention: Patch coverage is 33.33333% with 10 lines in your changes missing coverage. Please review.

Project coverage is 98.80%. Comparing base (803ee23) to head (d3dba8f). Report is 2 commits behind head on main.

:exclamation: Current head d3dba8f differs from the pull request's most recent head 195c8fa. Consider uploading reports for commit 195c8fa to get more accurate results.

| Files | Patch % | Lines |
|---|---|---|
| imgutils/tagging/wd14.py | 33.33% | 10 Missing :warning: |
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main      #94      +/-   ##
==========================================
- Coverage   99.16%   98.80%   -0.36%
==========================================
  Files          93       93
  Lines        2752     2766      +14
==========================================
+ Hits         2729     2733       +4
- Misses         23       33      +10
```

| [Flag](https://app.codecov.io/gh/deepghs/imgutils/pull/94/flags?src=pr&el=flags&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=deepghs) | Coverage Δ | |
|---|---|---|
| [unittests](https://app.codecov.io/gh/deepghs/imgutils/pull/94/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=deepghs) | `98.80% <33.33%> (-0.36%)` | :arrow_down: |

Flags with carried forward coverage won't be shown. [Click here](https://docs.codecov.io/docs/carryforward-flags?utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=deepghs#carryforward-flags-in-the-pull-request-comment) to find out more.
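The project-level figures in the coverage diff can be reproduced from the Hits and Lines rows. A small sketch, assuming coverage is hits divided by tracked lines and that the percentage is truncated (rounded down) to two decimals: truncation matches the reported 98.80%, whereas conventional rounding would give 98.81%, which appears consistent with Codecov's default `round: down` setting. `coverage_pct` is a hypothetical helper, not a Codecov API.

```python
import math


def coverage_pct(hits, lines, precision=2):
    """Hypothetical helper: hits/lines as a percentage, truncated to `precision` decimals."""
    factor = 10 ** precision
    # Integer arithmetic avoids float artifacts before truncating
    return (hits * 100 * factor // lines) / factor


base_hits, base_lines = 2729, 2752   # main
head_hits, head_lines = 2733, 2766   # #94

base = coverage_pct(base_hits, base_lines)      # 99.16
head = coverage_pct(head_hits, head_lines)      # 98.8
delta = round(head - base, 2)                   # -0.36
misses = (base_lines - base_hits, head_lines - head_hits)  # (23, 33)

# Conventional rounding would report 98.81 instead of 98.80
rounded_head = round(head_hits / head_lines * 100, 2)

print(base, head, delta, misses, rounded_head)
```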
