Y-debug-sys / Diffusion-TS

[ICLR 2024] Official Implementation of "Diffusion-TS: Interpretable Diffusion for General Time Series Generation"
MIT License
187 stars 27 forks source link

关于生成数据评判标准 #72

Closed holydick99 closed 3 days ago

holydick99 commented 3 days ago

作者你好,在README文件中说到在Utils文件中有更多的评判指标计算,但是其中两个metric的文件中使用的是tensorflow的环境,这两个请问有pytorch下的实现吗?

Y-debug-sys commented 3 days ago

你好,为了与TimeGAN公平对比,直接使用了它的源码,所以没有。

holydick99 commented 3 days ago

你好,为了与TimeGAN公平对比,直接使用了它的源码,所以没有。

好的谢谢

holydick99 commented 3 days ago

import torch from Utils.cross_correlation import CrossCorrelLoss

x_real = torch.tensor([...]) # 替换为你的原始数据 x_fake = torch.tensor([...]) # 替换为你的生成数据

创建 CrossCorrelLoss 对象

cross_correl_loss = CrossCorrelLoss(x_real)

计算交叉相关系数的损失

loss = cross_correl_loss.forward(x_fake)

print(loss) 请问,这是正确的互相关计算代码吗