RVC-Boss / GPT-SoVITS

1 min voice data can also be used to train a good TTS model! (few shot voice cloning)
MIT License
32.73k stars 3.77k forks source link

参考音频泄漏以及CPU推理 #516

Open XXXXRT666 opened 7 months ago

XXXXRT666 commented 7 months ago

我使用的是最新版本的GPT-Sovits项目与M2 8GB,前几个小时刚进行过git pull,我尝试将config.py 以及 infer_webui.py中的device判断删掉,强行使用device="cpu"进行推理,并将其速度与同版本的MPS推理进行比较。有如下发现:

  1. 发现CPU推理python全程内存占用3GB,内存曲线全程绿色,推理速度长时间保持55it/s,GPU推理python内存占用持续稳步上升至14GB,推理速度最高30it/s,时现1-2it/s

  2. 使用5s参考音频:“试了一局,发现效果也不错,后面就,会逐渐的去。” 使用8s参考音频:”我的小乔集锦和孙尚香集锦倒是挺多的,最近也在收集这个六破军吕布的素材,说不定能赶在李信集锦之前做出来 。“ 文本1:十三枪,861字 文本2:古诗,337字 均为默认四句一切 使用8s音频时CPU与GPU推理文本1,2均出现了较为严重的参考音频泄漏情况,使用5s音频时均正常。CPU推理与GPU推理结果无太大差别。

请大佬解惑

附上实验结果 最后的时间为推理用时

截屏2024-02-17 22 15 18

实验过程CPU GPU占用情况

截屏2024-02-17 20 38 34 截屏2024-02-17 20 44 52
XXXXRT666 commented 7 months ago

附上推理用的文本 十三枪.txt 古诗.txt

RVC-Boss commented 7 months ago

楼主是用的预训练的模型(没微调的)吗? 另外参考音频太长的话是会有这种现象,校对标注文本微调后会好很多, 另外可以试试无参考文本模式,如果实在介意漏了参考文本结尾的话

XXXXRT666 commented 7 months ago

不是,我GPT和Sovits模型都是用大概2个小时的人工标注文本微调过的,在换参考音频之后确实就没有泄漏了,只是用一个模型mac的CPU跑的比autodl上3090还快,觉得挺奇怪的,来问问。

我刚试了下另一个数据集更小的模型,也是相同结果。但这些模型均为GPT-Sovits一月中下旬版本炼制,同模型在4090上推理速度为90it/s,若mac cpu推理速度更快为普遍现象,说不定可以在判断中将device设置成cpu进行加速。

Lion-Wu commented 7 months ago

感谢指出,经过测试后,我发现在 Mac 上使用 CPU 推理确实比 GPU 快很多。 所以我觉得使用 CPU 推理确实更好,既提高了速度,还没有 mps 的内存泄漏问题。 我稍后尝试提交 PR

XXXXRT666 commented 7 months ago

请问大佬对于为何M系列芯片CPU推理更快是否有想法?我测试过Amd和Intel的CPU推理,其效果均不如Mac的Cpu,推理速度远低于MPS推理

Lion-Wu commented 7 months ago

也许因为使用统一内存,延迟很低,而且内存带宽远高于其他电脑

v3ucn commented 7 months ago

直接强改cpu太武断了,是sonoma的问题,系统对pytorch的计算支持有bug,你们测试的os应该都是sonoma,随着系统更新,这个bug肯定会被修复

Lion-Wu commented 7 months ago

直接强改cpu太武断了,是sonoma的问题,系统对pytorch的计算支持有bug,你们测试的os应该都是sonoma,随着系统更新,这个bug肯定会被修复

不会吧,我刚试了一下Ventura,也是CPU明显快于GPU啊? Ventura下GPU速度时快时慢,范围在5it/s~35it/s波动,推理一段时间后速度还会突然降到一秒不到1个it(两个版本都这样),速度非常不稳定; CPU速度稳定在40左右,而且还没有GPU的内存泄漏问题,综合来讲我觉得应该是CPU会更好。 你提到的bug能提供具体信息吗?或者你测试是怎样?

v3ucn commented 7 months ago

https://github.com/pytorch/pytorch/issues/111517 参考帖子,修改cmakes的编译方式

cpu推理:

['zh']
 19%|███████▍                                | 280/1500 [00:12<00:47, 25.55it/s]T2S Decoding EOS [102 -> 382]
 19%|███████▍                                | 280/1500 [00:12<00:56, 21.54it/s]

gpu推理:

 21%|████████▌                               | 322/1500 [00:08<00:32, 36.46it/s]T2S Decoding EOS [102 -> 426]
 22%|████████▋                               | 324/1500 [00:08<00:29, 39.26it/s]

但确实有内存泄露的现象

Lion-Wu commented 7 months ago

想确认一下你是将 config.py 和 infer_webui.py 中的 device 都改成cpu了吗?另外你是什么型号什么版本?

XXXXRT666 commented 7 months ago

是的,我刚实验时发现了个问题,会出现半精度错误,好像M系列芯片is_half判断结果为True,所以还要pr把is_half改一下 我是M2 8G macos14.3.1

XXXXRT666 commented 7 months ago

将inference_webui.py中is_half = eval(os.environ.get("is_half", "True"))->is_half = eval(os.environ.get("is_half", "True")) and not torch.backends.mps.is_available()并将import torch提前可解决,config中由于当设定为cpu推理时is_half自动为False,无需修改

v3ucn commented 7 months ago

只改了infer_webui.py ,有必要改config吗?看了一下好像只有api会调用config的里面的device, 另外我的版本是python3.11 torch版本如下,不知道有没有影响,就是说我的版本都弄成了最新的 Successfully installed torch-2.3.0.dev20240221 torchaudio-2.2.0.dev20240221

XXXXRT666 commented 7 months ago

if(infer_device=="cpu"):is_half=False config.py较末尾处还有此判断

Lion-Wu commented 7 months ago

将inference_webui.py中is_half = eval(os.environ.get("is_half", "True"))->is_half = eval(os.environ.get("is_half", "True")) and not torch.backends.mps.is_available()并将import torch提前可解决,config中由于当设定为cpu推理时is_half自动为False,无需修改

macOS 14 以上以及最新pytorch nightly build以及支持BF16和float16了,叫GPT写了如下代码测试:

import torch

# 确保 PyTorch 可以访问 MPS
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"使用设备: {device}")

# 测试 float16
try:
    print("\n测试 float16...")
    a = torch.randn(3, 3, dtype=torch.float32, device=device)
    a = a.to(torch.float16)
    print(f"float16 计算结果: \n{a + a}")
    print("float16 测试成功,没有错误。")
except Exception as e:
    print(f"float16 测试失败: {e}")

# 测试 bf16 (请注意,这可能在某些 PyTorch 版本或设备上不受支持)
try:
    print("\n测试 bf16...")
    b = torch.randn(3, 3, dtype=torch.float32, device=device)
    b = b.to(torch.bfloat16)
    print(f"bf16 计算结果: \n{b + b}")
    print("bf16 测试成功,没有错误。")
except Exception as e:
    print(f"bf16 测试失败: {e}")

Ventura下提示需要macOS 14以上系统,Sonoma下输出:

使用设备: mps
测试 float16...
float16 计算结果: 
tensor([[ 1.7812, -0.8813,  0.7148],
        [ 3.2305, -3.3828, -0.6938],
        [-0.1914, -0.3582,  1.5400]], device='mps:0', dtype=torch.float16)
float16 测试成功,没有错误。

测试 bf16...
bf16 计算结果: 
tensor([[0.3789, 0.7344, 0.4902],
        [0.5508, 1.5078, 1.3906],
        [3.5781, 3.7656, 0.7383]], device='mps:0', dtype=torch.bfloat16)
bf16 测试成功,没有错误。

所以应该不用修改。

Lion-Wu commented 7 months ago

只改了infer_webui.py ,有必要改config吗?看了一下好像只有api会调用config的里面的device, 另外我的版本是python3.11 torch版本如下,不知道有没有影响,就是说我的版本都弄成了最新的 Successfully installed torch-2.3.0.dev20240221 torchaudio-2.2.0.dev20240221

我试了下把config.py改回mps,速度确实慢了……你改一下试试? 另外我记得README写着:Note: numba==0.56.4 require py<3.11,你这是装了其他版本吗?我使用的是3.9

XXXXRT666 commented 7 months ago

Sonoma得到类似结果,估计我的torch版本不支持bf16


使用设备: mps

测试 float16...
float16 计算结果: 
tensor([[ 1.9229,  2.0703, -1.1650],
        [-0.8335,  1.1064, -0.4231],
        [ 0.8765,  0.9316, -0.5986]], device='mps:0', dtype=torch.float16)
float16 测试成功,没有错误。

测试 bf16...
bf16 测试失败: BFloat16 is not supported on MPS
XXXXRT666 commented 7 months ago

只改了infer_webui.py ,有必要改config吗?看了一下好像只有api会调用config的里面的device, 另外我的版本是python3.11 torch版本如下,不知道有没有影响,就是说我的版本都弄成了最新的 Successfully installed torch-2.3.0.dev20240221 torchaudio-2.2.0.dev20240221

我试了下把config.py改回mps,速度确实慢了……你改一下试试? 另外我记得README写着:Note: numba==0.56.4 require py<3.11,你这是装了其他版本吗?我使用的是3.9

我尝试的结果是mps确实挺慢的,如果将is_half写死为False就更慢了,将is_half写死后推理出现3-5it/s的频率增加

Lion-Wu commented 7 months ago

Sonoma得到类似结果,估计我的torch版本不支持bf16


使用设备: mps

测试 float16...
float16 计算结果: 
tensor([[ 1.9229,  2.0703, -1.1650],
        [-0.8335,  1.1064, -0.4231],
        [ 0.8765,  0.9316, -0.5986]], device='mps:0', dtype=torch.float16)
float16 测试成功,没有错误。

测试 bf16...
bf16 测试失败: BFloat16 is not supported on MPS

你有装nightly build吗? 不得是装nightly build才可以训练模型嘛吗我记得,你这个看起来好像BF16也得nightly build

XXXXRT666 commented 7 months ago

Sonoma得到类似结果,估计我的torch版本不支持bf16


使用设备: mps

测试 float16...
float16 计算结果: 
tensor([[ 1.9229,  2.0703, -1.1650],
        [-0.8335,  1.1064, -0.4231],
        [ 0.8765,  0.9316, -0.5986]], device='mps:0', dtype=torch.float16)
float16 测试成功,没有错误。

测试 bf16...
bf16 测试失败: BFloat16 is not supported on MPS

你有装nightly build吗? 不得是装nightly build才可以训练模型嘛吗我记得,你这个看起来好像BF16也得nightly build

我训练模型直接上云的,本地跑估计快不到哪里去

XXXXRT666 commented 7 months ago

刚开了一下本地训练,发现无报错,训练不了,不知道是否是我内存不够还是torch版本问题

v3ucn commented 7 months ago

是用的python3.11,所有依赖如下:

(base) ➜  ~ pip list
Package                       Version
----------------------------- -----------------
addict                        2.4.0
aiobotocore                   2.5.0
aiofiles                      22.1.0
aiohttp                       3.8.5
aioitertools                  0.7.1
aiosignal                     1.2.0
aiosqlite                     0.18.0
alabaster                     0.7.12
aliyun-python-sdk-core        2.14.0
aliyun-python-sdk-kms         2.16.2
altair                        5.2.0
anaconda-anon-usage           0.4.2
anaconda-catalogs             0.2.0
anaconda-client               1.12.1
anaconda-cloud-auth           0.1.3
anaconda-navigator            2.5.0
anaconda-project              0.11.1
annotated-types               0.6.0
anyio                         3.5.0
appdirs                       1.4.4
applaunchservices             0.3.0
appnope                       0.1.2
appscript                     1.1.2
argon2-cffi                   21.3.0
argon2-cffi-bindings          21.2.0
arrow                         1.2.3
astroid                       2.14.2
astropy                       5.1
asttokens                     2.0.5
async-timeout                 4.0.2
atomicwrites                  1.4.0
attrs                         22.1.0
audioread                     3.0.1
Automat                       20.2.0
autopep8                      1.6.0
av                            11.0.0
Babel                         2.11.0
backcall                      0.2.0
backports.functools-lru-cache 1.6.4
backports.tempfile            1.0
backports.weakref             1.0.post1
bcrypt                        3.2.0
beautifulsoup4                4.12.2
binaryornot                   0.4.4
black                         0.0
bleach                        4.1.0
bokeh                         3.2.1
boltons                       23.0.0
botocore                      1.29.76
Bottleneck                    1.3.5
brotlipy                      0.7.0
certifi                       2023.7.22
cffi                          1.15.1
chardet                       4.0.0
charset-normalizer            2.0.4
click                         8.0.4
cloudpickle                   2.2.1
clyent                        1.2.2
cn2an                         0.5.22
colorama                      0.4.6
colorcet                      3.0.1
coloredlogs                   15.0.1
comm                          0.1.2
conda                         23.7.4
conda-build                   3.26.1
conda-content-trust           0.2.0
conda_index                   0.3.0
conda-libmamba-solver         23.7.0
conda-pack                    0.6.0
conda-package-handling        2.2.0
conda_package_streaming       0.9.0
conda-repo-cli                1.0.75
conda-token                   0.4.0
conda-verify                  3.4.2
constantly                    15.1.0
contourpy                     1.0.5
cookiecutter                  1.7.3
crcmod                        1.7
cryptography                  41.0.3
cssselect                     1.1.0
ctranslate2                   4.0.0
cycler                        0.11.0
cytoolz                       0.12.0
dask                          2023.6.0
datasets                      2.17.1
datashader                    0.15.2
datashape                     0.5.4
debugpy                       1.6.7
decorator                     5.1.1
defusedxml                    0.7.1
diff-match-patch              20200713
dill                          0.3.6
Distance                      0.1.3
distributed                   2023.6.0
docstring-to-markdown         0.11
docutils                      0.18.1
einops                        0.7.0
entrypoints                   0.4
et-xmlfile                    1.1.0
executing                     0.8.3
fastapi                       0.109.2
faster-whisper                0.10.0
fastjsonschema                2.16.2
ffmpeg-python                 0.2.0
ffmpy                         0.3.2
filelock                      3.9.0
flake8                        6.0.0
Flask                         2.2.2
flatbuffers                   23.5.26
fonttools                     4.25.0
frozenlist                    1.3.3
fsspec                        2023.10.0
future                        0.18.3
g2p-en                        2.1.0
gast                          0.5.4
gensim                        4.3.0
glob2                         0.7
gmpy2                         2.1.2
gradio                        3.38.0
gradio_client                 0.7.0
greenlet                      2.0.1
h11                           0.14.0
h5py                          3.9.0
HeapDict                      1.0.1
holoviews                     1.17.1
httpcore                      1.0.3
httpx                         0.26.0
huggingface-hub               0.20.3
humanfriendly                 10.0
hvplot                        0.8.4
hyperlink                     21.0.0
idna                          3.4
imagecodecs                   2023.1.23
imageio                       2.31.1
imagesize                     1.4.1
imbalanced-learn              0.10.1
importlib-metadata            6.0.0
importlib-resources           6.1.1
incremental                   21.3.0
inflect                       7.0.0
inflection                    0.5.1
iniconfig                     1.1.1
intake                        0.6.8
intervaltree                  3.1.0
ipykernel                     6.25.0
ipython                       8.15.0
ipython-genutils              0.2.0
ipywidgets                    8.0.4
isort                         5.9.3
itemadapter                   0.3.0
itemloaders                   1.0.4
itsdangerous                  2.0.1
jaraco.classes                3.2.1
jedi                          0.18.1
jellyfish                     1.0.1
jieba-fast                    0.53
Jinja2                        3.1.2
jinja2-time                   0.2.0
jmespath                      0.10.0
joblib                        1.2.0
json5                         0.9.6
jsonpatch                     1.32
jsonpointer                   2.1
jsonschema                    4.17.3
jupyter                       1.0.0
jupyter_client                7.4.9
jupyter-console               6.6.3
jupyter_core                  5.3.0
jupyter-events                0.6.3
jupyter-server                1.23.4
jupyter_server_fileid         0.9.0
jupyter_server_ydoc           0.8.0
jupyter-ydoc                  0.2.4
jupyterlab                    3.6.3
jupyterlab-pygments           0.1.2
jupyterlab_server             2.22.0
jupyterlab-widgets            3.0.5
kaleido                       0.2.1
keyring                       23.13.1
kiwisolver                    1.4.4
lazy_loader                   0.2
lazy-object-proxy             1.6.0
libarchive-c                  2.9
libmambapy                    1.5.1
librosa                       0.10.1
lightning-utilities           0.10.1
linkify-it-py                 2.0.0
llvmlite                      0.40.0
lmdb                          1.4.1
locket                        1.0.0
lxml                          4.9.3
lz4                           4.3.2
Markdown                      3.4.1
markdown-it-py                2.2.0
MarkupSafe                    2.1.1
matplotlib                    3.7.2
matplotlib-inline             0.1.6
mccabe                        0.7.0
mdit-py-plugins               0.3.0
mdurl                         0.1.0
mistune                       0.8.4
modelscope                    1.12.0
more-itertools                8.12.0
mpmath                        1.3.0
msgpack                       1.0.3
multidict                     6.0.2
multipledispatch              0.6.0
multiprocess                  0.70.14
munkres                       1.1.4
mypy-extensions               1.0.0
navigator-updater             0.4.0
nbclassic                     0.5.5
nbclient                      0.5.13
nbconvert                     6.5.4
nbformat                      5.9.2
nest-asyncio                  1.5.6
networkx                      3.1
nltk                          3.8.1
notebook                      6.5.4
notebook_shim                 0.2.2
numba                         0.57.1
numexpr                       2.8.4
numpy                         1.24.3
numpydoc                      1.5.0
onnxruntime                   1.17.0
openpyxl                      3.0.10
orjson                        3.9.14
oss2                          2.18.4
packaging                     23.1
pandas                        2.0.3
pandocfilters                 1.5.0
panel                         1.2.3
param                         1.13.0
parsel                        1.6.0
parso                         0.8.3
partd                         1.4.0
pathlib                       1.0.1
pathspec                      0.10.3
patsy                         0.5.3
pep8                          1.7.1
pexpect                       4.8.0
pickleshare                   0.7.5
Pillow                        9.4.0
pip                           23.2.1
pkce                          1.0.3
pkginfo                       1.9.6
platformdirs                  3.10.0
plotly                        5.9.0
pluggy                        1.0.0
ply                           3.11
pooch                         1.8.1
poyo                          0.5.0
proces                        0.1.7
prometheus-client             0.14.1
prompt-toolkit                3.0.36
Protego                       0.1.16
protobuf                      4.25.3
psutil                        5.9.0
ptyprocess                    0.7.0
pure-eval                     0.2.2
py-cpuinfo                    8.0.0
pyarrow                       15.0.0
pyarrow-hotfix                0.6
pyasn1                        0.4.8
pyasn1-modules                0.2.8
pycodestyle                   2.10.0
pycosat                       0.6.4
pycparser                     2.21
pycryptodome                  3.20.0
pyct                          0.5.0
pycurl                        7.45.2
pydantic                      2.6.1
pydantic_core                 2.16.2
PyDispatcher                  2.0.5
pydocstyle                    6.3.0
pydub                         0.25.1
pyerfa                        2.0.0
pyflakes                      3.0.1
Pygments                      2.15.1
PyJWT                         2.4.0
pylint                        2.16.2
pylint-venv                   2.3.0
pyls-spyder                   0.4.0
pyobjc-core                   9.0
pyobjc-framework-Cocoa        9.0
pyobjc-framework-CoreServices 9.0
pyobjc-framework-FSEvents     9.0
pyodbc                        4.0.34
pyopenjtalk                   0.3.3
pyOpenSSL                     23.2.0
pyparsing                     3.0.9
pypinyin                      0.50.0
PyQt5-sip                     12.11.0
pyrsistent                    0.18.0
PySocks                       1.7.1
pytest                        7.4.0
python-dateutil               2.8.2
python-dotenv                 0.21.0
python-json-logger            2.0.7
python-lsp-black              1.2.1
python-lsp-jsonrpc            1.0.0
python-lsp-server             1.7.2
python-multipart              0.0.9
python-slugify                5.0.2
python-snappy                 0.6.1
pytoolconfig                  1.2.5
pytorch-lightning             2.2.0.post0
pytz                          2023.3.post1
pyviz-comms                   2.3.0
PyWavelets                    1.4.1
PyYAML                        6.0
pyzmq                         23.2.0
QDarkStyle                    3.0.2
qstylizer                     0.2.2
QtAwesome                     1.2.2
qtconsole                     5.4.2
QtPy                          2.2.0
queuelib                      1.5.0
regex                         2022.7.9
requests                      2.31.0
requests-file                 1.5.1
requests-toolbelt             1.0.0
responses                     0.13.3
rfc3339-validator             0.1.4
rfc3986-validator             0.1.1
rich                          13.7.0
rope                          1.7.0
Rtree                         1.0.1
ruamel.yaml                   0.17.21
ruamel-yaml-conda             0.17.21
s3fs                          2023.4.0
safetensors                   0.3.2
scikit-image                  0.20.0
scikit-learn                  1.3.0
scipy                         1.11.1
Scrapy                        2.8.0
seaborn                       0.12.2
semantic-version              2.10.0
Send2Trash                    1.8.0
service-identity              18.1.0
setuptools                    68.0.0
shellingham                   1.5.4
simplejson                    3.19.2
sip                           6.6.2
six                           1.16.0
smart-open                    5.2.1
sniffio                       1.2.0
snowballstemmer               2.2.0
sortedcontainers              2.4.0
soundfile                     0.12.1
soupsieve                     2.4
soxr                          0.3.7
Sphinx                        5.0.2
sphinxcontrib-applehelp       1.0.2
sphinxcontrib-devhelp         1.0.2
sphinxcontrib-htmlhelp        2.0.0
sphinxcontrib-jsmath          1.0.1
sphinxcontrib-qthelp          1.0.3
sphinxcontrib-serializinghtml 1.1.5
spyder                        5.4.3
spyder-kernels                2.4.4
SQLAlchemy                    1.4.39
stack-data                    0.2.0
starlette                     0.36.3
statsmodels                   0.14.0
sympy                         1.11.1
tables                        3.8.0
tabulate                      0.8.10
tblib                         1.7.0
tenacity                      8.2.2
terminado                     0.17.1
text-unidecode                1.3
textdistance                  4.2.1
threadpoolctl                 2.2.0
three-merge                   0.1.1
tifffile                      2023.4.12
tinycss2                      1.2.1
tldextract                    3.2.0
tokenizers                    0.13.2
toml                          0.10.2
tomlkit                       0.12.0
toolz                         0.12.0
torch                         2.3.0.dev20240221
torchaudio                    2.2.0.dev20240221
torchmetrics                  1.3.1
tornado                       6.3.2
tqdm                          4.65.0
traitlets                     5.7.1
transformers                  4.32.1
Twisted                       22.10.0
typer                         0.9.0
typing_extensions             4.9.0
tzdata                        2023.3
uc-micro-py                   1.0.1
ujson                         5.4.0
Unidecode                     1.2.0
urllib3                       1.26.16
uvicorn                       0.27.1
w3lib                         1.21.0
watchdog                      2.1.6
wcwidth                       0.2.5
webencodings                  0.5.1
websocket-client              0.58.0
websockets                    11.0.3
Werkzeug                      2.2.3
whatthepatch                  1.0.2
wheel                         0.38.4
widgetsnbextension            4.0.5
wrapt                         1.14.1
wurlitzer                     3.0.2
xarray                        2023.6.0
xlwings                       0.29.1
xxhash                        2.0.2
xyzservices                   2022.9.0
y-py                          0.5.9
yapf                          0.31.0
yarl                          1.8.1
ypy-websocket                 0.8.2
zict                          2.2.0
zipp                          3.11.0
zope.interface                5.4.0
zstandard                     0.19.0

只改了推理页面的device,就是强制"mps:0",可以跑,没有任何问题

Lion-Wu commented 7 months ago

Sonoma得到类似结果,估计我的torch版本不支持bf16


使用设备: mps

测试 float16...
float16 计算结果: 
tensor([[ 1.9229,  2.0703, -1.1650],
        [-0.8335,  1.1064, -0.4231],
        [ 0.8765,  0.9316, -0.5986]], device='mps:0', dtype=torch.float16)
float16 测试成功,没有错误。

测试 bf16...
bf16 测试失败: BFloat16 is not supported on MPS

你有装nightly build吗? 不得是装nightly build才可以训练模型嘛吗我记得,你这个看起来好像BF16也得nightly build

我训练模型直接上云的,本地跑估计快不到哪里去

本地还好,我M1 Max用一分钟的样本大概十分钟左右搞定,但是要是一个小时的样本,我发现有比较严重的内存泄漏,导致速度变慢。考虑到你是M2,并且只有8GB内存,可能确实不适合,我32GB的都是用了交换内存来训练…… CPU训练我之前尝试过,比mps慢一些,但还是可以接受的,还不会有内存泄漏,我甚至感觉这种情况用CPU训练更好😂

Lion-Wu commented 7 months ago

刚开了一下本地训练,发现无报错,训练不了,不知道是否是我内存不够还是torch版本问题

理论上不会吧,就算内存不足也会用交换内存,你可能没有装nightly版本,按照README里Mac的安装方法试试

Lion-Wu commented 7 months ago

@v3ucn 你试试两个都改成CPU,看看速度怎样

XXXXRT666 commented 7 months ago

Sonoma得到类似结果,估计我的torch版本不支持bf16


使用设备: mps

测试 float16...
float16 计算结果: 
tensor([[ 1.9229,  2.0703, -1.1650],
        [-0.8335,  1.1064, -0.4231],
        [ 0.8765,  0.9316, -0.5986]], device='mps:0', dtype=torch.float16)
float16 测试成功,没有错误。

测试 bf16...
bf16 测试失败: BFloat16 is not supported on MPS

你有装nightly build吗? 不得是装nightly build才可以训练模型嘛吗我记得,你这个看起来好像BF16也得nightly build

我训练模型直接上云的,本地跑估计快不到哪里去

本地还好,我M1 Max用一分钟的样本大概十分钟左右搞定,但是要是一个小时的样本,我发现有比较严重的内存泄漏,导致速度变慢。考虑到你是M2,并且只有8GB内存,可能确实不适合,我32GB的都是用了交换内存来训练…… CPU训练我之前尝试过,比mps慢一些,但还是可以接受的,还不会有内存泄漏,我甚至感觉这种情况用CPU训练更好😂

我用的素材集有三个多小时,本地训练基本无缘的

Lion-Wu commented 7 months ago

Sonoma得到类似结果,估计我的torch版本不支持bf16


使用设备: mps

测试 float16...
float16 计算结果: 
tensor([[ 1.9229,  2.0703, -1.1650],
        [-0.8335,  1.1064, -0.4231],
        [ 0.8765,  0.9316, -0.5986]], device='mps:0', dtype=torch.float16)
float16 测试成功,没有错误。

测试 bf16...
bf16 测试失败: BFloat16 is not supported on MPS

你有装nightly build吗? 不得是装nightly build才可以训练模型嘛吗我记得,你这个看起来好像BF16也得nightly build

我训练模型直接上云的,本地跑估计快不到哪里去

本地还好,我M1 Max用一分钟的样本大概十分钟左右搞定,但是要是一个小时的样本,我发现有比较严重的内存泄漏,导致速度变慢。考虑到你是M2,并且只有8GB内存,可能确实不适合,我32GB的都是用了交换内存来训练…… CPU训练我之前尝试过,比mps慢一些,但还是可以接受的,还不会有内存泄漏,我甚至感觉这种情况用CPU训练更好😂

我用的素材集有三个多小时,本地训练基本无缘的

那确实,不过要是都改成CPU而且愿意等的话说不定还真可以😂

XXXXRT666 commented 7 months ago

我去试试看 依据#290 的改法能不能跑