Open linengcs opened 2 months ago
Hello! I tried running jittor_target = jittor.array(target, dtype=jittor.int64) locally and did not hit any error. My test code is as follows:
import numpy as np
import jittor as jt
target = np.random.randint(-1, 13, (480, 480))
jittor_target = jt.array(target, dtype=jt.int64)
print(jittor_target.shape)
The output is:
[480,480,]
This shows that calling jt.array is not what causes the error.
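From the traceback you posted:
Traceback (most recent call last):
  File "/home/ubuntu/hdd2/llf/miniconda3/envs/fdlnet_j/lib/python3.8/site-packages/jittor/dataset/dataset.py", line 258, in _worker_main
    batch.append(self[i])
it looks like you are writing a dataset. In my own experience I ran into something similar: when I overrode `__getitem__` in a Dataset subclass, returning a jittor type always raised an error, whereas np.ndarray, a native list, and Image.Image did not.
So if this code runs inside the dataset, try removing the conversion to a jittor type and return the numpy type directly.
Applied to your code, you could try the following: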
def _mask_transform(self, mask):
    # Convert the PIL.Image to a numpy array
    mask_np = np.array(mask).astype('int32')
    print(f"mask_np type: {type(mask_np)}, mask_np dtype: {mask_np.dtype}, mask_np shape: {mask_np.shape}")
    print(f"mask_np min value: {mask_np.min()}, mask_np max value: {mask_np.max()}")
    # Apply the _class_to_index method
    target = self._class_to_index(mask_np)
    print(
        f"target after _class_to_index type: {type(target)}, target dtype: {target.dtype}, target shape: {target.shape}")
    print(f"target min value: {target.min()}, target max value: {target.max()}")
    # Convert the result to a numpy array and set its dtype
    target = np.array(target).astype('int64')
    print(f"target type: {type(target)}, target dtype: {target.dtype}, target shape: {target.shape}")
    print(f"target min value: {target.min()}, target max value: {target.max()}")
    # Check for abnormal values
    if np.any(np.isnan(target)):
        print("target contains NaN values")
    if np.any(np.isinf(target)):
        print("target contains infinite values")
    return target
    # ------------------ Skip the conversion and return the numpy type directly ------------------
    # Try converting to jittor.array
    # try:
    #     jittor_target = jittor.array(target, dtype=jittor.int64)
    #     print(f"jittor_target type: {type(jittor_target)}, jittor_target dtype: {jittor_target.dtype}")
    # except Exception as e:
    #     print(f"Error converting to Jittor array: {e}")
    #     import traceback
    #     traceback.print_exc()
    # return jittor_target
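If it helps while debugging, here is a small sketch of a check that flags anything other than plain NumPy or Python values in a sample before `__getitem__` returns it. `check_sample` and `SAFE_TYPES` are hypothetical names, not part of Jittor. The allowed types follow the experience described above (PIL images are left out only to keep the sketch dependency-free), and the three-element sample with a file name in the last slot is an assumption for illustration.

```python
import numpy as np

# Hypothetical allow-list, not a Jittor API: plain NumPy / Python values only.
SAFE_TYPES = (np.ndarray, np.generic, int, float, str, bool, type(None))


def check_sample(sample, path="sample"):
    """Raise TypeError if any leaf of `sample` is not a plain NumPy/Python value."""
    if isinstance(sample, (list, tuple)):
        for i, item in enumerate(sample):
            check_sample(item, f"{path}[{i}]")
    elif isinstance(sample, dict):
        for key, item in sample.items():
            check_sample(item, f"{path}[{key!r}]")
    elif not isinstance(sample, SAFE_TYPES):
        raise TypeError(
            f"{path} has type {type(sample).__name__}; "
            "return a NumPy array or plain Python value instead"
        )
    return sample


# Usage: an assumed (image, target, name) sample built from NumPy arrays.
image = np.zeros((3, 480, 480), dtype=np.float32)
target = np.random.randint(-1, 13, (480, 480)).astype("int64")
check_sample((image, target, "example.png"))
```

Calling `return check_sample((image, target, name))` at the end of `__getitem__` would then point at the offending element by index instead of failing later inside a worker process.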
Thanks! The problem turned out to be in how the dataset is loaded. The first approach calls set_attrs directly:
self.train_dataloader = self.train_dataset.set_attrs(
    batch_size=args.batch_size, shuffle=True, num_workers=args.workers
)
self.val_dataloader = self.val_dataset.set_attrs(
    batch_size=args.batch_size, shuffle=False, num_workers=args.workers
)
This approach triggers the problem above. Calling set_attrs indirectly, through the DataLoader function, does not trigger it, though I don't know the exact reason:
self.train_dataloader = DataLoader(self.train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
self.val_dataloader = DataLoader(self.val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
The dataset is iterated as follows:
for iteration, (images, targets, _) in enumerate(self.train_dataloader):
When the following line in the function below is executed:
jittor_target = jittor.array(target, dtype=jittor.int64)
an error is raised. Error message
Detailed log