Closed lsz05 closed 3 months ago
N/A
STSにscipyの関数を使っており,その内部にtensor.numpy()の処理があり,tensorがtorch.bfloat16タイプの場合,不具合が起こります。
scipy
tensor.numpy()
tensor
torch.bfloat16
[rank2]: File "~/JMTEB/src/jmteb/evaluators/sts/evaluator.py", line 74, in __call__ [rank2]: val_results[sim_name], _ = self._compute_similarity( [rank2]: File "~/JMTEB/src/jmteb/evaluators/sts/evaluator.py", line 112, in _compute_similarity [rank2]: pearson = pearsonr(golden_scores, sim_scores)[0] [rank2]: File "/usr/local/lib/python3.10/dist-packages/scipy/stats/_stats_py.py", line 4727, in pearsonr [rank2]: y = np.asarray(y) [rank2]: File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 1087, in __array__ [rank2]: return self.numpy() [rank2]: TypeError: Got unsupported ScalarType BFloat16
sim_scoresがtorch.bfloat16である場合,.float()のdtype変更
sim_scores
.float()
dtype
LGTM
関連する Issue / PR
N/A
PR をマージした後の挙動の変化
STSに
scipy
の関数を使っており,その内部にtensor.numpy()
の処理があり,tensor
がtorch.bfloat16
タイプの場合,不具合が起こります。挙動の変更を達成するために行ったこと
sim_scores
がtorch.bfloat16
である場合,.float()
のdtype
変更動作確認