forestry-labs / Rforestry

https://forestry-labs.github.io/Rforestry/

Getting weights variables #100

Closed: petrovicboban closed this issue 1 year ago

petrovicboban commented 1 year ago

Can you fix the function definition here

    def _get_weights_variables(self, weights: np.ndarray) -> np.ndarray:
        weights_variables = [i for i in range(weights.size) if weights[i] > max(weights) * 0.001]
        if len(weights_variables) < self.mtry:
            raise ValueError("mtry is too large. Given the feature weights, can't select that many features.")

        weights_variables = np.array(weights_variables, dtype=np.ulonglong)
        return weights_variables

for cases where weights is None (the default value of both feature_weights and deep_feature_weights in fit())? Also, the argument weights should be typed Optional[np.ndarray].
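One possible fix is sketched below as a standalone function. It is not the maintainers' patch, and it rests on two assumptions. First, it treats weights=None as uniform weights, so every feature is eligible. Second, because None carries no length, it takes a hypothetical n_features argument, and mtry is passed in directly instead of being read from self. When weights are given, the threshold is computed once rather than calling max() for every element.

```python
from typing import Optional

import numpy as np


def get_weights_variables(
    weights: Optional[np.ndarray], n_features: int, mtry: int
) -> np.ndarray:
    """Return indices of features eligible for sampling.

    Hypothetical standalone variant of _get_weights_variables: n_features and
    mtry replace attributes of self and are assumptions, not the library API.
    """
    if weights is None:
        # Assumption: no weights means uniform weights, so all features qualify.
        weights_variables = np.arange(n_features, dtype=np.ulonglong)
    else:
        weights = np.asarray(weights, dtype=float)
        # Keep features whose weight exceeds 0.1% of the largest weight,
        # computing the threshold a single time.
        threshold = weights.max() * 0.001
        weights_variables = np.flatnonzero(weights > threshold).astype(np.ulonglong)

    if weights_variables.size < mtry:
        raise ValueError(
            "mtry is too large. Given the feature weights, can't select that many features."
        )
    return weights_variables
```

For example, get_weights_variables(None, 4, 2) returns all four indices, while get_weights_variables(np.array([1.0, 0.0, 0.5]), 3, 1) returns indices 0 and 2.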

@edwardwliu

petrovicboban commented 1 year ago

Let's forget this for now.