pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

How to calculate integrated gradients of each state in a DRL actor network (e.g., the SAC algorithm)? #1123

Open 1900360 opened 1 year ago

1900360 commented 1 year ago

As I understand it, the SAC algorithm has one actor network and two critic networks. I want to rank the importance of the DRL state features by computing the integrated gradients of each state feature and then sorting them. So I wonder whether it is possible to compute them like this (my `attribute` call did not work):

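        import numpy as np
        import torch
        from captum.attr import IntegratedGradients

        self.actor.eval()                   # actor (policy) network
        torch.manual_seed(123)
        np.random.seed(123)
        # observations as a float tensor; Captum sets requires_grad on the
        # interpolated inputs itself, so no manual on/off toggling is needed
        obs = torch.as_tensor(batch.obs, dtype=torch.float32)
        # baseline: the reference input (not sure what it should be);
        # here all-zero observations with exactly the same shape as obs
        baseline = torch.zeros_like(obs)
        # forward_func must return a (batch, action_dim) tensor; target=0 = first action dim
        self.integrated_gradients = IntegratedGradients(self.actor)
        attributions, delta = self.integrated_gradients.attribute(
            obs, baselines=baseline, target=0, return_convergence_delta=True)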
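For reference, this is the quantity the call above is meant to compute. With one observation $x$, a baseline $x'$ (the reference input the code asks about), and $F$ the actor output for the chosen target action, the integrated gradient of state feature $i$ and its completeness property are:

$$
\mathrm{IG}_i(x) = (x_i - x'_i)\int_0^1 \frac{\partial F\big(x' + \alpha\,(x - x')\big)}{\partial x_i}\,d\alpha,
\qquad
\sum_i \mathrm{IG}_i(x) = F(x) - F(x').
$$

So the baseline is the "no information" input that attributions are measured against; all-zero observations are a common choice but not an automatic one. The `delta` returned with `return_convergence_delta=True` measures how far the numerically summed attributions are from $F(x) - F(x')$. To rank features over a batch $B$, one simple aggregation (an assumption here, not something the method prescribes) is $\mathrm{importance}_i = \frac{1}{|B|}\sum_{x \in B} |\mathrm{IG}_i(x)|$.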

Could you give me some suggestions? I really need this. :)
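A minimal, self-contained sketch of the ranking idea, assuming a Gaussian SAC actor that returns `(mean, log_std)`. `GaussianActor`, `MeanAction` and `integrated_gradients` are hypothetical names for illustration, not part of Captum or any RL library. IG is computed by hand with a midpoint Riemann sum so the sketch runs without Captum; with Captum installed, `IntegratedGradients(policy).attribute(obs, baselines=baseline, target=0)` computes the same quantity (by default with a Gauss-Legendre rule instead of a midpoint sum).

```python
import torch
import torch.nn as nn


class GaussianActor(nn.Module):
    """Hypothetical SAC-style actor: returns (mean, log_std) of a Gaussian policy."""

    def __init__(self, obs_dim: int, act_dim: int, hidden: int = 64):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
        self.mu = nn.Linear(hidden, act_dim)
        self.log_std = nn.Linear(hidden, act_dim)

    def forward(self, obs):
        h = self.body(obs)
        return self.mu(h), self.log_std(h)


class MeanAction(nn.Module):
    """Hypothetical wrapper that exposes one tensor output for attribution."""

    def __init__(self, actor: nn.Module):
        super().__init__()
        self.actor = actor

    def forward(self, obs):
        mu, _ = self.actor(obs)
        return torch.tanh(mu)  # deterministic squashed action, shape (batch, act_dim)


def integrated_gradients(f, x, baseline, target, n_steps=64):
    """Midpoint Riemann approximation of IG for output column `target`."""
    alphas = (torch.arange(n_steps, dtype=x.dtype) + 0.5) / n_steps
    total_grad = torch.zeros_like(x)
    for a in alphas:
        point = (baseline + a * (x - baseline)).detach().requires_grad_(True)
        # Summing over the batch still yields per-sample gradients, because
        # each sample's output depends only on its own input row.
        out = f(point)[:, target].sum()
        (grad,) = torch.autograd.grad(out, point)
        total_grad += grad
    return (x - baseline) * total_grad / n_steps


torch.manual_seed(123)
obs_dim, act_dim = 5, 2
actor = GaussianActor(obs_dim, act_dim).eval()
policy = MeanAction(actor)

obs = torch.randn(32, obs_dim)       # stand-in for batch.obs
baseline = torch.zeros_like(obs)     # all-zero reference observations
attr = integrated_gradients(policy, obs, baseline, target=0)

# Completeness check: per-sample attributions should sum to F(x) - F(baseline).
delta = attr.sum(dim=1) - (policy(obs)[:, 0] - policy(baseline)[:, 0])

# Rank state features by mean absolute attribution over the batch.
importance = attr.abs().mean(dim=0)
ranking = torch.argsort(importance, descending=True)
print("importance:", importance.tolist())
print("features, most to least important:", ranking.tolist())
```

The wrapper is the key design choice: an attribution method needs a forward function that returns a tensor from which `target` selects one column per sample, while many SAC actors return a tuple or a distribution object. Which action dimension to target, and which baseline to use, remain modelling decisions.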

1900360 commented 1 year ago

Is anybody here?