oegedijk / explainerdashboard

Quickly build Explainable AI dashboards that show the inner workings of so-called "blackbox" machine learning models.
http://explainerdashboard.readthedocs.io
MIT License
2.3k stars 331 forks source link

Update component plots when selecting data #292

Closed soundgarden134 closed 8 months ago

soundgarden134 commented 9 months ago

Hello, I'm making a custom dashboard with ExplainerDashboard components and a map. The idea is to be able to select a region in the map to filter the data and re calculate the shap values in order to understand a certain area's predictions by seeing the feature importances in this area in particular. However, since I'm not an expert in Dash I haven't been able to update the components. After being initialized correctly, once I select an area of the map and trigger the callback, the component plots end up empty. This is my (shortened) code:

dash.py (omitting initial setup)

app = Dash(__name__)
server = app.server
map_tab = RegressionDashboard(consolidated, eb_explainer, model, model_type, name="Regression Dashboard", app=app)

app.layout = html.Div([
    map_tab.layout()
])

map_tab.register_callbacks(app)

if __name__ == "__main__":
    log.info('Starting dashboard server ...')
    app.run(port=6660, host='0.0.0.0')

regression_dashboard.py

class RegressionDashboard(ExplainerComponent):
    def __init__(self, consolidated, explainer,  model, model_type, app, source_crs='EPSG:32719',name=None,**kwargs):
        super().__init__(explainer, title="Map")
        # a lot of self.(something) lines
        self.contrib = ShapContributionsGraphComponent(explainer,
                            hide_selector=True, hide_cats=True,
                            hide_depth=True, hide_sort=True,
                            **kwargs)
        self.shap_summary = ShapSummaryComponent(explainer, hide_selector=True, hide_cats=True,
                            hide_depth=True, hide_sort=True, hide_type=True,
                            **kwargs) #Feature importances basically, edit title 
        self.shap_dependance = ShapDependenceComponent(explainer, hide_selector=True, hide_cats=True,
                            hide_depth=True, hide_sort=True, plot_sample=100000,
                            **kwargs)
        self.shap_dependance_connector = ShapSummaryDependenceConnector(self.shap_summary, self.shap_dependance)

 #terrible layout, just for testing purposes
    def layout(self):
        self.map_fig = self.create_map()

        return html.Div(
            html.Div([
                html.Div(
                dcc.Graph(figure=self.map_fig, id="preds_map", style={'height': '45vh'}),
                style={
                    'width': '50%', 
                    'display': 'inline-block', 
                    'border': 'thin lightgrey solid', 
                    'boxSizing': 'border-box',
                    'height': '50vh'
                }
                ),
                html.Div([
                    self.contrib.layout(),
                    self.shap_summary.layout(),
                    self.shap_dependance.layout(),
                ],
                )
            ],
            style={
                'width': '100%', 
                'height': '60vh'
        }),
        id='layout-container')

    def update_layout_components(self):
        return html.Div([
            html.Div(
                dcc.Graph(figure=self.map_fig, id="preds_map", style={'height': '45vh'}),
                style={
                    'width': '50%', 
                    'display': 'inline-block', 
                    'border': 'thin lightgrey solid', 
                    'boxSizing': 'border-box',
                    'height': '50vh'
                }
            ),
            html.Div([
                self.contrib.layout(),
                self.shap_summary.layout(),
                self.shap_dependance.layout(),
            ]),
        ], 
        style={
            'width': '100%', 
            'height': '60vh'
        })

    def create_map(self, filtered_data = None, max_points = None):

      #map code, irrelevant
        return fig

    def transform_coordinates(self, df, x_col, y_col, source_crs):
    # transform coordinates from one system to another, irrelevant
        return df

#I want to filter by coordinates but right now I'm just trying to update the plots by just making a 
# random subsample of the data to prove
# the plots are updating
    def update_components(self):
        predictor = self.model.steps[-1][1]
        X_transformed, blockids = consolidated_to_X(self.consolidated.sample(n=3000, random_state=42), self.model)
        X_transformed.drop(['long', 'lat'], axis=1, inplace=True)
        explainer = RegressionExplainer(model=predictor, X=X_transformed, n_jobs=-1, index_name="Block ID", 
                                precision="float32", target="DEPVAR")
        shap_explainer = shap.Explainer(predictor, X_transformed)
        shap_values = shap_explainer.shap_values(X_transformed, check_additivity=False, approximate=True)
        base_values = shap_explainer.expected_value
        explainer.set_shap_values(base_values, shap_values)
        self.contrib = ShapContributionsGraphComponent(explainer,
                            hide_selector=True, hide_cats=True,
                            hide_depth=True, hide_sort=True,
                            )
        self.shap_summary = ShapSummaryComponent(explainer, hide_selector=True, hide_cats=True,
                            hide_depth=True, hide_sort=True, hide_type=True,
                            ) #Feature importances basically, edit title 
        self.shap_dependance = ShapDependenceComponent(explainer, hide_selector=True, hide_cats=True,
                            hide_depth=True, hide_sort=True, plot_sample=100000,
                            )
        self.shap_dependance_connector = ShapSummaryDependenceConnector(self.shap_summary, self.shap_dependance)

    def component_callbacks(self, app):
        @app.callback(
        Output('layout-container', 'children'),
        Input('preds_map', 'selectedData'),
        prevent_initial_call=True)
        def update_selected_data(selectedData):
            if not selectedData:
                raise PreventUpdate
            self.update_components()
            new_layout = self.update_layout_components()
            return new_layout

What am I missing here? I know there's probably a lot of unnecesary code here and it's really messy, but I'm really losing my mind over this. Any help is greatly appreciated. Thanks!

datatalking commented 9 months ago

Please show a screenshot or copy of the error you get, once we can see what your output is I could pair program with you to see if I can help.

oegedijk commented 9 months ago

It looks like you are reinstantiating all the sub components every time there is an update? I don't think that will work. All the components have to be there at the start of the dash app, and then the callbacks update the properties of the components, so the self.update_components() line is probably that breaks things. So get rid of that. The shap values have already been pre-calculated in the explainer that gets passed to the component, so you never reinstantiate the explainer and then recalculate.

So the callback should probably target Output('preds_map', 'figure') instead and then you simply generate the right plot for a subsample.

soundgarden134 commented 8 months ago

Hello, and sorry for the late response. I think I didn't explain properly what the problem was. What I want to do basically is to update the plot ShapSummaryComponent with a callback in my custom tab. Whenever I select an area in the preds_map, a new explainer should be created with the filtered data within the map, changing the calculated shap values of the features. WIth this new explainer, I would like to update the ShapSummaryComponent in order to provide insights about particular areas and which features had the biggest impact in these specific areas. Is this possible? Thanks in advance.

UPDATE: I managed to do it. What I did was basically use a callback in my map that recalculates the shap values for the selected area, re calculates the explainer and sets the new explainer for the shap summary component and updates a hidden "P" html element once its finished its calculations. Then I added in the update_shap_summary_graph another input with this element that I named "shap-hidden-trigger" and every time this element changes, the callback is called again with the new explainer.

shap_components.py

        @app.callback(
            [
                Output("shap-summary-graph-" + self.name, "figure"),
                Output("shap-summary-index-col-" + self.name, "style"),
            ],
            [
                Input("shap-summary-type-" + self.name, "value"),
                Input("shap-summary-depth-" + self.name, "value"),
                Input("shap-summary-index-" + self.name, "value"),
                Input("pos-label-" + self.name, "value"),
                Input("shap-hidden-trigger", "children"),
            ],
        )
        def update_shap_summary_graph(summary_type, depth, index, pos_label, hidden_trigger):
         #rest of the code

my_custom_layout.py

  @app.callback(
            Output('shap-hidden-trigger', 'children'),
            Input('update-shap-button', 'n_clicks'),
            State('preds-map', 'selectedData'),
            prevent_initial_call=True,
        )
        def update_shap_summary(change_settings, selectedData):
             #some processing of the data that I skipped

                shap_explainer = shap.Explainer(predictor)
                shap_values = shap_explainer.shap_values(
                    new_X, 
                    check_additivity=False, 
                    approximate=False
                )
                base_values = shap_explainer.expected_value
                new_explainer = ClassifierExplainer(model=predictor, X=new_X, n_jobs=-1, index_name="Block ID", 
                                                    precision="float32", target="DEPVAR")
                new_explainer.set_shap_values(base_values, shap_values)
                self.shap_summary.explainer = new_explainer
                return "Turing" #just updating to whatever since its hidden

Thanks!