nnaisense / evotorch

Advanced evolutionary computation library built directly on top of PyTorch, created at NNAISENSE.
https://evotorch.ai
Apache License 2.0
997 stars 62 forks source link

problem._get_best_and_worst #84

Closed — haohaomiao closed this issue 1 year ago

haohaomiao commented 1 year ago

Is there something wrong with this implementation? It seems that when `len(senses) == 1` and `self._best` is None, only one of `_best` and `_worst` gets set, while the other remains None. This would cause an error, because both `_best` and `_worst` are accessed when the function returns.

    @torch.no_grad()
    def _get_best_and_worst(self, batch: "SolutionBatch") -> Optional[dict]:
        """
        Update and report the best and worst solutions observed so far.

        Running records are kept per objective in ``self._best``,
        ``self._worst``, ``self._best_evals`` and ``self._worst_evals``.
        They are lazily initialized on the first call where ``self._best``
        is None.

        Args:
            batch: The newly evaluated batch whose solutions are compared
                against the running records.
        Returns:
            An empty dict when solution statistics are not stored. Whether
            they are stored is decided once, from whether the first batch's
            device is the CPU, unless ``self._store_solution_stats`` was
            already set.
            For a single objective: a dict with keys ``best``, ``worst``,
            ``best_eval`` and ``worst_eval``.
            For several objectives: a dict with keys ``best`` and ``worst``,
            each a list with one (cloned) solution per objective.
        Raises:
            ValueError: If an objective sense is neither "min" nor "max".
        """
        if self._store_solution_stats is None:
            # Decided once, from the device of the first batch seen.
            self._store_solution_stats = str(batch.device) == "cpu"

        if not self._store_solution_stats:
            return {}

        senses = self.senses
        nobjs = len(senses)

        if self._best is None:
            self._best_evals = self.make_empty(nobjs, device=batch.device, use_eval_dtype=True)
            self._worst_evals = self.make_empty(nobjs, device=batch.device, use_eval_dtype=True)
            # Start every record at the "least extreme" sentinel for its sense.
            for i_obj, sense in enumerate(senses):
                if sense == "min":
                    self._best_evals[i_obj] = float("inf")
                    self._worst_evals[i_obj] = float("-inf")
                elif sense == "max":
                    self._best_evals[i_obj] = float("-inf")
                    self._worst_evals[i_obj] = float("inf")
                else:
                    raise ValueError(f"Invalid sense: {sense}")
            self._best = [None] * nobjs
            self._worst = [None] * nobjs

        def first_is_better(a, b, i_obj):
            if senses[i_obj] == "min":
                return a < b
            elif senses[i_obj] == "max":
                return a > b
            else:
                raise ValueError(f"Invalid sense: {senses[i_obj]}")

        def first_is_worse(a, b, i_obj):
            if senses[i_obj] == "min":
                return a > b
            elif senses[i_obj] == "max":
                return a < b
            else:
                raise ValueError(f"Invalid sense: {senses[i_obj]}")

        best_sln_indices = [batch.argbest(i) for i in range(nobjs)]
        worst_sln_indices = [batch.argworst(i) for i in range(nobjs)]

        for i_obj in range(nobjs):
            best_sln_index = best_sln_indices[i_obj]
            worst_sln_index = worst_sln_indices[i_obj]
            scores = batch.access_evals(i_obj)
            best_score = scores[best_sln_index]
            worst_score = scores[worst_sln_index]
            # A record that is still None is always filled. A strict
            # comparison against the ±inf sentinel is false when the score
            # equals the sentinel (or is NaN), which would otherwise leave the
            # slot as None and crash the `.evals` access in the return below.
            if self._best[i_obj] is None or first_is_better(best_score, self._best_evals[i_obj], i_obj):
                self._best_evals[i_obj] = best_score
                self._best[i_obj] = batch[best_sln_index].clone()
            if self._worst[i_obj] is None or first_is_worse(worst_score, self._worst_evals[i_obj], i_obj):
                self._worst_evals[i_obj] = worst_score
                self._worst[i_obj] = batch[worst_sln_index].clone()

        if nobjs == 1:
            return dict(
                best=self._best[0],
                worst=self._worst[0],
                best_eval=float(self._best[0].evals[0]),
                worst_eval=float(self._worst[0].evals[0]),
            )
        else:
            return {"best": self._best, "worst": self._worst}