[Open] KevinWang905 opened this issue 2 years ago.
@PSSF23 Could you assign me to this issue? Thanks!
@nhahn7 I'm also saving the test here in case the test file changes in the future.
def test_update_task(self):
    """Check that ``update_task`` appends new data to an existing task.

    Task 0 is first created with ``X``/``y``. ``update_task`` is then called
    with a second batch ``X2``/``y2`` for the same ``task_id``. Afterwards the
    data stored for task 0 must equal the first batch followed by the second
    batch, in that order and with exactly the same shape.
    """
    np.random.seed(1)
    l2f = LifelongClassificationForest()

    # Initial task: 100 samples with feature value 0 labelled 0, followed by
    # 100 samples with feature value 1 labelled 1; X is a single-column matrix.
    X = np.concatenate((np.zeros(100), np.ones(100))).reshape(-1, 1)
    y = np.concatenate((np.zeros(100), np.ones(100)))
    l2f.add_task(X, y)

    # The update batch holds the same samples in the opposite order. If it
    # were identical to the first batch, "X then X2" and "X2 then X" would be
    # indistinguishable and the ordering check below would be vacuous.
    X2 = np.concatenate((np.ones(100), np.zeros(100))).reshape(-1, 1)
    y2 = np.concatenate((np.ones(100), np.zeros(100)))
    l2f.update_task(X2, y2, task_id=0)

    expected_X = np.concatenate((X, X2))
    expected_y = np.concatenate((y, y2))
    # np.array_equal, unlike np.array_equiv, also requires identical shapes,
    # so a flattened or merely broadcast-compatible result is rejected.
    assert np.array_equal(l2f.task_id_to_X[0], expected_X)
    assert np.array_equal(l2f.task_id_to_y[0], expected_y)
Part of #34: Add a function that updates previous tasks based on new data with task labels