mit-han-lab / fastcomposer

[IJCV] FastComposer: Tuning-Free Multi-Subject Image Generation with Localized Attention
https://fastcomposer.mit.edu
MIT License

environment.yml #26

Closed malteprinzler closed 1 year ago

malteprinzler commented 1 year ago

Please find the (unofficial) environment.yml below. I had some trouble setting up the environment correctly myself, and I hope this helps someone else.

DISCLAIMER: I have not yet checked if the model converges to the same scores as provided in the paper, but at least training works with batch size 16 on 80GB A100 GPUs.

environment.yml:

name: fastcomposer
channels:
  - pytorch
  - nvidia
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - abseil-cpp=20211102.0=hd4dd3e8_0
  - accelerate=0.19.0=pyhd8ed1ab_0
  - aiofiles=22.1.0=py310h06a4308_0
  - aiohttp=3.8.5=py310h5eee18b_0
  - aiosignal=1.2.0=pyhd3eb1b0_0
  - altair=5.0.1=py310h06a4308_0
  - annotated-types=0.6.0=pyhd8ed1ab_0
  - anyio=3.7.1=pyhd8ed1ab_0
  - arrow-cpp=11.0.0=h374c478_2
  - async-timeout=4.0.2=py310h06a4308_0
  - attrs=23.1.0=py310h06a4308_0
  - aws-c-common=0.6.8=h5eee18b_1
  - aws-c-event-stream=0.1.6=h6a678d5_6
  - aws-checksums=0.1.11=h5eee18b_2
  - aws-sdk-cpp=1.8.185=h721c034_1
  - blas=1.0=mkl
  - boost-cpp=1.82.0=hdb19cb5_2
  - bottleneck=1.3.5=py310ha9d4c09_0
  - brotli=1.0.9=h5eee18b_7
  - brotli-bin=1.0.9=h5eee18b_7
  - brotli-python=1.0.9=py310h6a678d5_7
  - bzip2=1.0.8=h7b6447c_0
  - c-ares=1.19.1=h5eee18b_0
  - ca-certificates=2023.11.17=hbcca054_0
  - certifi=2023.11.17=pyhd8ed1ab_0
  - cffi=1.16.0=py310h5eee18b_0
  - charset-normalizer=2.0.4=pyhd3eb1b0_0
  - click=8.1.7=py310h06a4308_0
  - colorama=0.4.6=pyhd8ed1ab_0
  - contourpy=1.2.0=py310hdb19cb5_0
  - cryptography=41.0.3=py310hdda0065_0
  - cuda-cudart=11.7.99=0
  - cuda-cupti=11.7.101=0
  - cuda-libraries=11.7.1=0
  - cuda-nvrtc=11.7.99=0
  - cuda-nvtx=11.7.91=0
  - cuda-runtime=11.7.1=0
  - cycler=0.11.0=pyhd3eb1b0_0
  - datasets=2.12.0=py310h06a4308_0
  - diffusers=0.16.1=pyhd8ed1ab_0
  - dill=0.3.6=py310h06a4308_0
  - exceptiongroup=1.0.4=py310h06a4308_0
  - fastapi=0.103.2=pyhd8ed1ab_0
  - ffmpeg=4.3=hf484d3e_0
  - ffmpy=0.3.0=pyhb6f538c_0
  - filelock=3.13.1=pyhd8ed1ab_0
  - fonttools=4.25.0=pyhd3eb1b0_0
  - freetype=2.12.1=h4a9f257_0
  - frozenlist=1.4.0=py310h5eee18b_0
  - fsspec=2023.10.0=pyhca7485f_0
  - ftfy=6.1.3=pyhd8ed1ab_0
  - gflags=2.2.2=he6710b0_0
  - giflib=5.2.1=h5eee18b_3
  - glog=0.5.0=h2531618_0
  - gmp=6.2.1=h295c915_3
  - gnutls=3.6.15=he1e5248_0
  - gradio=3.50.2=pyhd8ed1ab_0
  - gradio-client=0.6.1=pyhd8ed1ab_0
  - grpc-cpp=1.48.2=he1ff14a_1
  - h11=0.12.0=pyhd3eb1b0_0
  - h2=4.0.0=py310h06a4308_3
  - hpack=4.0.0=py_0
  - httpcore=0.15.0=py310h06a4308_0
  - httpx=0.23.0=py310h06a4308_0
  - huggingface_hub=0.19.4=pyhd8ed1ab_0
  - hyperframe=6.0.1=pyhd3eb1b0_0
  - icu=73.1=h6a678d5_0
  - idna=3.4=py310h06a4308_0
  - importlib-metadata=6.8.0=pyha770c72_0
  - importlib_resources=6.1.0=py310h06a4308_0
  - intel-openmp=2023.1.0=hdb19cb5_46306
  - jinja2=3.1.2=py310h06a4308_0
  - joblib=1.2.0=py310h06a4308_0
  - jpeg=9e=h5eee18b_1
  - jsonschema=4.19.2=py310h06a4308_0
  - jsonschema-specifications=2023.7.1=py310h06a4308_0
  - kiwisolver=1.4.4=py310h6a678d5_0
  - krb5=1.20.1=h143b758_1
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.38=h1181459_1
  - lerc=3.0=h295c915_0
  - libboost=1.82.0=h109eef0_2
  - libbrotlicommon=1.0.9=h5eee18b_7
  - libbrotlidec=1.0.9=h5eee18b_7
  - libbrotlienc=1.0.9=h5eee18b_7
  - libcublas=11.10.3.66=0
  - libcufft=10.7.2.124=h4fbf590_0
  - libcufile=1.8.1.2=0
  - libcurand=10.3.4.101=0
  - libcurl=8.4.0=h251f7ec_0
  - libcusolver=11.4.0.1=0
  - libcusparse=11.7.4.91=0
  - libdeflate=1.17=h5eee18b_1
  - libedit=3.1.20221030=h5eee18b_0
  - libev=4.33=h7f8727e_1
  - libevent=2.1.12=hdbd6064_1
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=13.2.0=h807b86a_3
  - libgomp=13.2.0=h807b86a_3
  - libiconv=1.16=h7f8727e_2
  - libidn2=2.3.4=h5eee18b_0
  - libnghttp2=1.57.0=h2d74bed_0
  - libnpp=11.7.4.75=0
  - libnvjpeg=11.8.0.2=0
  - libpng=1.6.39=h5eee18b_0
  - libprotobuf=3.20.3=he621ea3_0
  - libssh2=1.10.0=hdbd6064_2
  - libstdcxx-ng=11.2.0=h1234567_1
  - libtasn1=4.19.0=h5eee18b_0
  - libthrift=0.15.0=h1795dd8_2
  - libtiff=4.5.1=h6a678d5_0
  - libunistring=0.9.10=h27cfd23_0
  - libuuid=1.41.5=h5eee18b_0
  - libwebp=1.3.2=h11a3e52_0
  - libwebp-base=1.3.2=h5eee18b_0
  - lz4-c=1.9.4=h6a678d5_0
  - markdown-it-py=2.2.0=py310h06a4308_1
  - markupsafe=2.1.1=py310h7f8727e_0
  - matplotlib-base=3.8.0=py310h1128e8f_0
  - mdurl=0.1.0=py310h06a4308_0
  - mkl=2023.1.0=h213fc3f_46344
  - mkl-service=2.4.0=py310h5eee18b_1
  - mkl_fft=1.3.8=py310h5eee18b_0
  - mkl_random=1.2.4=py310hdb19cb5_0
  - multidict=6.0.2=py310h5eee18b_0
  - multiprocess=0.70.14=py310h06a4308_0
  - munkres=1.1.4=py_0
  - ncurses=6.4=h6a678d5_0
  - nettle=3.7.3=hbbd107a_1
  - numexpr=2.8.7=py310h85018f9_0
  - numpy=1.26.0=py310h5f9d8c6_0
  - numpy-base=1.26.0=py310hb5e798b_0
  - openai-clip=1.0.1=pyhd8ed1ab_0
  - openh264=2.1.1=h4ff587b_0
  - openjpeg=2.4.0=h3ad879b_0
  - openssl=3.2.0=hd590300_0
  - orc=1.7.4=hb3bc3d3_1
  - orjson=3.8.8=py310h52d8a92_0
  - packaging=23.2=pyhd8ed1ab_0
  - pandas=2.1.1=py310h1128e8f_0
  - pillow=10.0.1=py310ha6cbd5a_0
  - pip=23.3=py310h06a4308_0
  - psutil=5.9.0=py310h5eee18b_0
  - pyarrow=11.0.0=py310h468efa6_1
  - pycparser=2.21=pyhd3eb1b0_0
  - pydantic=2.5.1=pyhd8ed1ab_0
  - pydantic-core=2.14.3=py310hcb5633a_0
  - pydub=0.25.1=pyhd8ed1ab_0
  - pygments=2.15.1=py310h06a4308_1
  - pyopenssl=23.2.0=py310h06a4308_0
  - pyparsing=3.0.9=py310h06a4308_0
  - pysocks=1.7.1=py310h06a4308_0
  - python=3.10.13=h955ad1f_0
  - python-dateutil=2.8.2=pyhd3eb1b0_0
  - python-multipart=0.0.6=py310h06a4308_0
  - python-tzdata=2023.3=pyhd3eb1b0_0
  - python-xxhash=2.0.2=py310h5eee18b_1
  - python_abi=3.10=2_cp310
  - pytorch=1.13.1=py3.10_cuda11.7_cudnn8.5.0_0
  - pytorch-cuda=11.7=h778d358_5
  - pytorch-mutex=1.0=cuda
  - pytz=2023.3.post1=py310h06a4308_0
  - pyyaml=6.0.1=py310h2372a71_1
  - re2=2022.04.01=h295c915_0
  - readline=8.2=h5eee18b_0
  - referencing=0.30.2=py310h06a4308_0
  - regex=2023.10.3=py310h2372a71_0
  - requests=2.31.0=py310h06a4308_0
  - responses=0.13.3=pyhd3eb1b0_0
  - rfc3986=1.4.0=pyhd3eb1b0_0
  - rich=13.3.5=py310h06a4308_0
  - rpds-py=0.10.6=py310hb02cf49_0
  - sacremoses=0.0.43=pyhd3eb1b0_0
  - semantic_version=2.8.5=pyhd3eb1b0_0
  - setuptools=68.0.0=py310h06a4308_0
  - shellingham=1.5.0=py310h06a4308_0
  - six=1.16.0=pyhd3eb1b0_1
  - snappy=1.1.9=h295c915_0
  - sniffio=1.2.0=py310h06a4308_1
  - sqlite=3.41.2=h5eee18b_0
  - starlette=0.27.0=py310h06a4308_0
  - tbb=2021.8.0=hdb19cb5_0
  - tk=8.6.12=h1ccaba5_0
  - tokenizers=0.13.3=py310h22610ee_0
  - tomlkit=0.12.0=pyha770c72_0
  - toolz=0.12.0=py310h06a4308_0
  - torchaudio=0.13.1=py310_cu117
  - torchvision=0.14.1=py310_cu117
  - tqdm=4.66.1=pyhd8ed1ab_0
  - transformers=4.29.2=py310h06a4308_0
  - typer=0.9.0=py310h06a4308_0
  - typing-extensions=4.7.1=py310h06a4308_0
  - typing_extensions=4.7.1=py310h06a4308_0
  - tzdata=2023c=h04d1e81_0
  - urllib3=1.26.18=py310h06a4308_0
  - utf8proc=2.6.1=h27cfd23_0
  - uvicorn=0.20.0=py310h06a4308_0
  - wcwidth=0.2.12=pyhd8ed1ab_0
  - websockets=10.4=py310h5eee18b_1
  - wheel=0.41.2=py310h06a4308_0
  - xxhash=0.8.0=h7f8727e_3
  - xz=5.4.2=h5eee18b_0
  - yaml=0.2.5=h7f98852_2
  - yarl=1.8.1=py310h5eee18b_0
  - zipp=3.17.0=pyhd8ed1ab_0
  - zlib=1.2.13=h5eee18b_0
  - zstd=1.5.5=hc292b87_0
  - pip:
      - antlr4-python3-runtime==4.9.3
      - appdirs==1.4.4
      - docker-pycreds==0.4.0
      - facenet-pytorch==2.5.3
      - gitdb==4.0.11
      - gitpython==3.1.40
      - hydra-core==1.3.2
      - omegaconf==2.3.0
      - protobuf==4.25.1
      - sentry-sdk==1.37.1
      - setproctitle==1.3.3
      - smmap==5.0.1
      - snakeviz==2.2.0
      - tornado==6.3.3
      - wandb==0.16.0

Note that two minor adjustments to the code are necessary: https://github.com/mit-han-lab/fastcomposer/blob/5282a7e5f5a1640a26d7a1f87cd901f00f50966f/fastcomposer/model.py#L192C12-L192C12 has to be modified to

    def _build_causal_attention_mask(self, bsz, seq_len, dtype):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
        mask.fill_(torch.tensor(torch.finfo(dtype).min))
        mask.triu_(1)  # zero out the diagonal and everything below it
        mask = mask.unsqueeze(1)  # expand mask
        return mask

and https://github.com/mit-han-lab/fastcomposer/blob/5282a7e5f5a1640a26d7a1f87cd901f00f50966f/fastcomposer/train.py#L57C38-L57C38 has to be changed to

        project_dir=args.logging_dir,
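
For anyone wondering what the mask from the first change actually contains, here is a rough, dependency-free sketch in plain Python. It is not the torch code above: `build_causal_mask` and `min_value` are names made up for illustration, and the most negative finite Python float stands in for `torch.finfo(dtype).min`. It only mirrors the fill-then-`triu_(1)` pattern for a single sequence, without the batch and head dimensions.

```python
import sys


def build_causal_mask(seq_len, min_value=-sys.float_info.max):
    """Return a seq_len x seq_len additive causal mask as nested lists.

    Entry [i][j] is 0.0 when position i may attend to position j (j <= i)
    and min_value when j lies in the future (j > i). This is the same
    pattern as filling a tensor with the minimum value and keeping only
    the strictly upper triangle.
    """
    return [
        [min_value if j > i else 0.0 for j in range(seq_len)]
        for i in range(seq_len)
    ]


# Illustrative usage: show which positions each token can see.
mask = build_causal_mask(3)
for row in mask:
    print(["masked" if value != 0.0 else "open" for value in row])
```

Because the mask is additive, the zeros leave attention scores unchanged and the very large negative entries push the softmax weight of future positions to effectively zero.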

JoyHuYY1412 commented 9 months ago

Hi, thanks for sharing! I understand the second change (logging_dir), but why is the first change needed? What error does the original code produce?