Closed: aditya0by0 closed this 2 weeks ago
RaggedCollator is failing!
There is a potential misalignment issue in the RaggedCollator class when processing data where some labels are None. Currently, the code correctly omits None labels from the y list but does not simultaneously remove the corresponding features from the x list. This causes a misalignment between features and labels, leading to incorrect training or evaluation outcomes.
Test Case: tests/unit/collators/testRaggedCollator.test_call_with_missing_entire_labels
Currently, this test fails because the feature corresponding to the None label is not omitted, causing a misalignment between result.x and result.y.
Please let me know whether this test case is relevant and aligned with the purpose of the RaggedCollator class. Additionally, please confirm that the expected results in the test case are appropriate and consistent with the class's intended functionality.
To fix the issue, the features (x) should also be filtered based on the non_null_labels indices, ensuring that x and y remain aligned.
Here's the corrected portion of the code:
non_null_labels = [i for i, r in enumerate(y) if r is not None]
y = self.process_label_rows(
tuple(ye for i, ye in enumerate(y) if i in non_null_labels)
)
x = [xe for i, xe in enumerate(x) if i in non_null_labels] # Filter x based on non_null_labels
loss_kwargs["non_null_labels"] = non_null_labels
This ensures that both x and y contain only the valid (non-None) entries and that they remain properly aligned.
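For illustration, a toy run of the snippet above with made-up values:

x = ["A", "B", "C"]
y = [[1], None, [0]]
non_null_labels = [i for i, r in enumerate(y) if r is not None]  # [0, 2]
x = [xe for i, xe in enumerate(x) if i in non_null_labels]  # ["A", "C"]
y = [ye for i, ye in enumerate(y) if i in non_null_labels]  # [[1], [0]]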
> There is a potential misalignment issue in the RaggedCollator class when processing data where some labels are None. Currently, the code correctly omits None labels from the y list but does not simultaneously remove the corresponding features from the x list. This causes a misalignment between features and labels, leading to incorrect training or evaluation outcomes.
This is intended behaviour. In some training setups, we use a mixture of labelled and unlabelled data in combination with certain loss functions that allow for partially unlabelled data (e.g. fuzzy loss). In order to compute the usual metrics (F1, MSE, etc.), one needs to filter out the predictions for unlabelled data and compute the metrics only on labelled data. The indices of these data points are stored in the 'non_null_labels' field and used by our implementations of Electra and MixedLoss.
Therefore, the shape of y should only align with x modulo non_null_labels.
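For clarity, here is a minimal sketch (not the actual Electra or MixedLoss code; all names are illustrative) of how the stored indices can re-align a batch of predictions with the labelled targets before computing metrics:

import torch

def metric_inputs(preds: torch.Tensor, y: torch.Tensor, non_null_labels: list) -> tuple:
    # preds has one row per batch element (aligned with x);
    # y contains rows only for the labelled examples.
    labelled_preds = preds[non_null_labels]
    return labelled_preds, y

# Toy batch of 3 inputs where only indices 0 and 2 are labelled:
preds = torch.rand(3, 5)
y = torch.randint(0, 2, (2, 5))
labelled_preds, y = metric_inputs(preds, y, [0, 2])
assert labelled_preds.shape == y.shape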
term_callback
A test case for term_callback is failing because it does not correctly ignore/skip obsolete ChEBI terms. As a result, the test cases for _extract_class_hierarchy and _graph_to_raw_dataset are also failing, since they use the output of term_callback.
Current Behavior:
The _graph_to_raw_dataset method filters out data instances:
data = data[~data["SMILES"].isnull()]  # drop rows without a SMILES string
data = data[data.iloc[:, self._LABELS_START_IDX:].any(axis=1)]  # drop rows with no positive label
So, even though obsolete terms are not specifically filtered, their lack of SMILES strings ensures they are excluded from the dataset.
Potential Future Issue:
Example of a Problematic Obsolete Term:
[Term]
id: CHEBI:77533
name: Compound G
is_a: CHEBI:99999
property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=C1Br" xsd:string
is_obsolete: true
If terms like this exist in future releases, the current approach could lead to errors because obsolete terms with SMILES strings might slip through the filters.
Proposed Solution:
We can update the term_callback logic to explicitly ignore obsolete terms by checking for the is_obsolete clause:
if isinstance(clause, fastobo.term.IsObsoleteClause):
    if clause.obsolete:
        # If the term document contains an "is_obsolete: true" clause, skip this term.
        return False
This solution would ensure that obsolete terms are skipped before they are processed, preventing potential future issues with the dataset.
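For context, a minimal sketch of where such a check could sit inside term_callback; apart from fastobo's own classes (which the snippet above already uses), the surrounding structure here is assumed, not the repository's actual implementation:

import fastobo

def term_callback(term_doc):
    # term_doc is assumed to be a fastobo.term.TermFrame; iterating it yields its clauses.
    for clause in term_doc:
        if isinstance(clause, fastobo.term.IsObsoleteClause) and clause.obsolete:
            # Skip obsolete terms even if they carry a SMILES string.
            return False
    # ... the existing extraction logic (id, name, SMILES, parents) would follow here
    return True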
Tox21MolNet
I've encountered an issue with the setup_processed method when working with the Tox21MolNet class and its data (the tox21.csv file). It appears that the file does not include a header or key named "group", which causes a KeyError in the line:
groups = np.array([d["group"] for d in data])
Additionally, the _load_data_from_file method does not seem to use any Reader to create or handle a "group" key in the data. As a result, the group key does not exist in the dictionaries produced by _load_data_from_file, leading to the observed error.
The _load_data_from_file method only yields three keys: features, labels, and ident:
yield dict(features=smiles, labels=labels, ident=row["mol_id"])
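One possible direction, sketched under the assumption that tox21.csv provides no grouping information at all, would be to yield an explicit group key and have setup_processed fall back to a non-grouped split:

# Sketch only, not a tested fix.
# In _load_data_from_file, make the key explicit:
yield dict(features=smiles, labels=labels, ident=row["mol_id"], group=None)

# In setup_processed, guard the grouped-split logic:
groups = np.array([d.get("group") for d in data])
if all(g is None for g in groups):
    ...  # fall back to a split that does not rely on groups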
Please let me know your suggestions on this issue.
As discussed, here are some additional test cases (I also added them at the top):
- ChEBIOverXPartial: should cover the one-label scenario from PR #54
- DynamicDataset: check whether the data splits are stratified
- setup_processed tests: should also check if the output has a structure that can be read by the collator (e.g., features should be tensor-able) -> expected to fail before #56 is resolved
- Readers: should also check if the "real" token order (as defined by tokens.txt) stays consistent
To ensure the token order in the "real" tokens.txt file remains consistent, we can maintain a corresponding duplicate tokens.txt file in the test directory. This duplicate file will serve as the reference for validating the order of tokens in the actual tokens.txt. During testing, we will compare the contents of the real file against this reference to check for consistency in both content and order.
Alternatively, we could verify the token order before and after any token insertion to ensure order consistency without the need for a duplicate file. However, this approach would be vulnerable to manual or direct changes in the tokens.txt file, which may not be detected.
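To make the first option concrete, here is a hedged sketch of such a test; both paths are hypothetical and would need to match the actual repository layout:

import unittest
from pathlib import Path

class TestTokenOrder(unittest.TestCase):
    # Hypothetical paths; adjust to the real locations in the repository.
    REAL_TOKENS = Path("chebai/preprocessing/bin/smiles_token/tokens.txt")
    REFERENCE_TOKENS = Path("tests/unit/mock_data/tokens.txt")

    def test_token_order_is_consistent(self):
        real = self.REAL_TOKENS.read_text().splitlines()
        reference = self.REFERENCE_TOKENS.read_text().splitlines()
        # New tokens may have been appended since the reference was created,
        # but the existing prefix must match in both content and order.
        self.assertEqual(real[: len(reference)], reference)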
Please let me know if you have any suggestions or alternative approaches to this method.
@sfluegel05, could you please provide your suggestions/input on the comment above?
I have added the test for protein pretraining. Now all the unit tests are working. Please review and merge.
Do you think it would be appropriate to include the unit tests related to Tox21MolNet in the same pull request or issue that addresses its rectification, specifically PR #56?
Thanks for finishing this. I removed the link to the unit test issue since we still have the toxicity-related unit tests which are not included in this PR.
I agree. I added a note for that in #56
Dependency: Issue #45

Unit Testing Checklist
reader.py
- to_data() with sample input values.
- _read_data() with sample SMILES strings.
- _read_data() with sample input values.
- _read_data() with sample SELFIES strings.
- _read_data() with sample protein sequences.

collate.py
- __call__() with sample data.
- __call__() with sample data.
- process_label_rows() with sample data.

datasets/base.py
- _filter_labels() with sample input values.
- get_test_split() with sample data.
- get_train_val_splits_given_test() with sample data.

datasets/chebi.py
- _extract_class_hierarchy() with mock data.
- _graph_to_raw_dataset() with mock data.
- _load_dict() with mock data.
- _setup_pruned_test_set() with mock data.
- select_classes() with sample data.
- extract_class_hierarchy() with mock data.
- term_callback() with sample data.

datasets/go_uniprot.py
- _extract_class_hierarchy() with mock data.
- term_callback() with sample data.
- _graph_to_raw_dataset() with mock data.
- _get_swiss_to_go_mapping() with mock data.
- _load_dict() with mock data.
- select_classes() with sample data.

datasets/tox21.py
- setup_processed() with mock data.
- _load_data_from_file() using mock file operations.
- _load_dict() with mock data.

datasets/protein_pretraining.py
- _parse_protein_data_for_pretraining() with mock data.

Note: Tests for Tox21MolNet will be added later in a separate PR/branch after completion of issue #53.