bxiang233 / ForAINet

official source code for paper entitled "Automated forest inventory: analysis of high-density airborne LiDAR point clouds with 3D deep learning"

An error in the Tree3DMix2 code #2

Open zqalex opened 4 months ago

zqalex commented 4 months ago

In PointCloudSegmentation/torch_points3d/core/data_transform/transforms.py, there is:

class Tree3DMix2(object):
    """ Tree3DMix2 prevents tree instances from overlapping """

    def __init__(self):
        pass

    def __call__(self, data, data2, stuff_classes):
        ...

However, when I run this code, torch_geometric's compose.py calls each transform with a single argument, "data", and never supplies "data2" or "stuff_classes", which raises the error below. How can this issue be resolved?

  File "/ForAINet-main/PointCloudSegmentation/torch_points3d/datasets/panoptic/treeins_set1.py", line 512, in __getitem__
    data = super().__getitem__(idx)
  File "site-packages/torch_geometric/data/dataset.py", line 194, in __getitem__
    data = data if self.transform is None else self.transform(data)
  File "site-packages/torch_geometric/transforms/compose.py", line 15, in __call__
    data = transform(data)
TypeError: __call__() missing 2 required positional arguments: 'data2' and 'stuff_classes'
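The failure can be reproduced without the repository. The following is a minimal sketch, assuming a simplified stand-in for the compose step; `Compose`, `ThreeArgTransform`, and `reproduce` are hypothetical names written for illustration and are not the torch_geometric or ForAINet classes.

```python
class Compose:
    """Simplified stand-in for a compose step (assumption, not the real class)."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, data):
        # Every transform is called with exactly one argument.
        for transform in self.transforms:
            data = transform(data)
        return data


class ThreeArgTransform:
    """Hypothetical transform with the same call signature as Tree3DMix2."""

    def __call__(self, data, data2, stuff_classes):
        return data


def reproduce():
    """Run the pipeline and return the TypeError message, or None if no error."""
    pipeline = Compose([ThreeArgTransform()])
    try:
        pipeline({"pos": [0.0, 0.0, 0.0]})
    except TypeError as err:
        return str(err)
    return None


message = reproduce()
print(message)
```

The printed message names the same two missing arguments as the traceback above, which suggests the problem lies in how the transform is called rather than in the transform itself.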

bxiang233 commented 3 months ago

Hi, thanks for reminding me of this error!

Please modify the __getitem__ method (for me it is on line 178) in the file

#YOUR_CONDA_LOCATION#/envs/#YOUR_ENV_NAME#/lib/python3.8/site-packages/torch_geometric/data/dataset.py

to:

def __getitem__(
    self,
    idx: Union[int, np.integer, IndexType],
) -> Union['Dataset', Data]:
    r"""In case :obj:`idx` is of type integer, will return the data object
    at index :obj:`idx` (and transforms it in case :obj:`transform` is
    present).
    In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a
    tuple, a PyTorch :obj:`LongTensor` or a :obj:`BoolTensor`, or a numpy
    :obj:`np.array`, will return a subset of the dataset at the specified
    indices."""
    if (isinstance(idx, (int, np.integer))
            or (isinstance(idx, Tensor) and idx.dim() == 0)
            or (isinstance(idx, np.ndarray) and np.isscalar(idx))):

        data = self.get(self.indices()[idx])
        # Original line, replaced by the per-transform dispatch below:
        # data = data if self.transform is None else self.transform(data)
        if self.transform is None:
            return data
        for transform in self.transform.transforms:
            if repr(transform) == "Mix3D":
                # Mix3D takes a second sample as an extra argument.
                data2 = self.get(self.indices()[idx])
                data = transform(data, data2)
            elif repr(transform) == "Tree3DMix":
                # Tree3DMix additionally needs the dataset's stuff classes.
                data2 = self.get(self.indices()[idx])
                data = transform(data, data2, self.stuff_classes)
            else:
                data = transform(data)
        return data

    else:
        return self.index_select(idx)
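A patch inside site-packages is lost whenever torch_geometric is reinstalled or upgraded. One possible alternative, which is not the author's method, is to keep the same dispatch in a helper that a dataset class calls from its own __getitem__. The sketch below assumes toy stand-ins throughout: `ToyCompose`, `ToyMix`, `ToyScale`, `ToyDataset`, and `apply_transforms` are hypothetical names, and only the repr-based dispatch is taken from the patch above.

```python
class ToyCompose:
    """Stand-in for a compose object: only exposes `.transforms` (assumption)."""

    def __init__(self, transforms):
        self.transforms = transforms


class ToyMix:
    """Hypothetical three-argument transform; its repr matches the patch's check."""

    def __repr__(self):
        return "Tree3DMix"

    def __call__(self, data, data2, stuff_classes):
        return {"points": data["points"] + data2["points"], "stuff": stuff_classes}


class ToyScale:
    """Hypothetical ordinary one-argument transform."""

    def __call__(self, data):
        return {**data, "points": [p * 2 for p in data["points"]]}


def apply_transforms(dataset, idx):
    """Apply transforms one by one, passing extra arguments where required."""
    data = dataset.get(dataset.indices()[idx])
    if dataset.transform is None:
        return data
    # A compose-like object exposes `.transforms`; a single transform does not.
    transforms = getattr(dataset.transform, "transforms", [dataset.transform])
    for transform in transforms:
        name = repr(transform)
        if name == "Mix3D":
            data2 = dataset.get(dataset.indices()[idx])
            data = transform(data, data2)
        elif name == "Tree3DMix":
            data2 = dataset.get(dataset.indices()[idx])
            data = transform(data, data2, dataset.stuff_classes)
        else:
            data = transform(data)
    return data


class ToyDataset:
    """Hypothetical dataset exposing the attributes the patch relies on."""

    stuff_classes = [0, 1]

    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def indices(self):
        return range(len(self.samples))

    def get(self, i):
        return dict(self.samples[i])  # copy so transforms cannot mutate storage

    def __getitem__(self, idx):
        return apply_transforms(self, idx)


ds = ToyDataset([{"points": [1, 2]}], ToyCompose([ToyMix(), ToyScale()]))
result = ds[0]
plain = ToyDataset([{"points": [1, 2]}])[0]
print(result)
```

The string compared against must equal whatever repr() returns for the transform actually in use. The issue mentions Tree3DMix2 while the patch checks for "Tree3DMix", so it may be worth printing repr(transform) once to confirm which branch is taken.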