LaurentMazare / tch-rs

Rust bindings for the C++ api of PyTorch.
Apache License 2.0

Is it possible to use ModuleDict? #402

Open dbsxdbsx opened 3 years ago

dbsxdbsx commented 3 years ago

For a more flexible torch network, a ModuleDict is sometimes useful. I saw that the C++ API has supported it since PyTorch 1.8 (see the related issue), but I could not find any documentation for it here. Is it possible to use it with tch-rs?

LaurentMazare commented 3 years ago

There is no strict equivalent to ModuleDict here, but I feel that the Sequential API lets you build models in a pretty similar way (and has the advantage of being more explicit about variable naming). For example, the following comes from the vgg implementation.

    seq.add_fn(|xs| xs.flat_view())
        .add(nn::linear(&c / "0", 512 * 7 * 7, 4096, Default::default()))
        .add_fn(|xs| xs.relu())
        .add_fn_t(|xs, train| xs.dropout(0.5, train))
        .add(nn::linear(&c / "3", 4096, 4096, Default::default()))
        .add_fn(|xs| xs.relu())
        .add_fn_t(|xs, train| xs.dropout(0.5, train))
        .add(nn::linear(&c / "6", 4096, nclasses, Default::default()))

Do you think that ModuleDict would be much of a gain here? Any thoughts on what this would look like?

dbsxdbsx commented 3 years ago

@LaurentMazare, I had already looked at the vgg code, because I wanted to know how to set train/eval mode for a torch model.
What I want here is a network that can not only update its weights through classic gradient descent but can also change its structure: some parts of the network hold ordinary parameters, while other parts can change their structure as well as their weights. The idea of a changeable structure comes from this paper: NEAT (related video).

I did build such a network as Python classes with PyTorch (not all of the code is shown, and I know it is still complex, so please bear with me):

import copy

import torch
import torch.nn as nn


class MyBlock(nn.Module):
    """
    The core submodule class whose structure changes with the genome.
    """
    def __init__(self, genome, state_dim, action_dim):
        """
        Each genome is a network candidate; each entry of
        genome.connection_genes holds the weight between two nodes.
        """
        super(MyBlock, self).__init__()
        self.genome = genome  # use reference here
        self.state_dim = state_dim
        self.action_dim = action_dim

        self.layer_map = torch.nn.ModuleDict()
        for i, each_conn_gene in enumerate(self.genome.connection_genes):
            if not each_conn_gene.is_enabled:
                continue
            # 1.init a new layer for this connection_gene
            new_layer = nn.Linear(1, 1, bias=False)

            # 2.reset the layer weights as in the connection_gene
            # NOTE: weight is a torch tensor. Shape is ([1]),like [1.2]
            weights = each_conn_gene.weight
            for p in new_layer.parameters():
                p.data = weights

            # 3. store it in the PyTorch ModuleDict, keyed by the gene's name
            # 3.1 set the connection_gene name so that the weight can be synced back later
            each_conn_gene.name = str(i)
            key_name = str(each_conn_gene.name)  # NOTE: not .innov_num
            # NOTE: can refer to it by custom_Layer.layer_map[key_name].weight
            # print(f"key name: {key_name}")
            self.layer_map[key_name] = new_layer

        self.node_value_map = self.init_node_value_map()

        # init node value lists
        self.input_node_id_list = [
            node.id for node in self.genome.node_genes if node.type == 'input'
        ]
        self.output_node_id_lst = [
            node.id for node in self.genome.node_genes if node.type == 'output'
        ]

    def init_node_value_map(self):
        return {node.id: None for node in self.genome.node_genes}

    def get_node_value(self, node_id):
        """
        Recursively compute the value of node_id, starting from the input nodes.
        """
        try:
            if self.node_value_map[node_id] is not None:
                return self.node_value_map[node_id]

            # try to find the node value
            # NOTE: get_connections_in would ignore disabled connection_gene
            connect_gene_lst = self.genome.get_connections_in(node_id)

            final_tensor = None
            for each_conn_gene in connect_gene_lst:
                # print(each_conn_gene)
                in_node_value_tensor = self.get_node_value(
                    each_conn_gene.in_node_id)
                assert in_node_value_tensor is not None
                assert len(in_node_value_tensor.shape) >= 2
                # xx = self.layer_map.keys()
                layer = self.layer_map[each_conn_gene.name]
                # layer = self.layer_map[str(each_conn_gene.innov_num)]
                each_conn_output_tensor = layer(
                    in_node_value_tensor)  # input tensor shape should be [bat_size,1]
                #
                assert len(
                    each_conn_output_tensor.shape) == 1  # even for batch_size>1
                each_conn_output_tensor = each_conn_output_tensor.unsqueeze(
                    1)  # add batch dim

                # sum all values forwarded to this node_id; the total is the value of node_id
                if final_tensor is None:  # first time
                    final_tensor = each_conn_output_tensor
                else:
                    final_tensor += each_conn_output_tensor

            assert final_tensor is not None
            self.node_value_map[node_id] = final_tensor
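        except Exception as e:
            # Re-raise with context instead of silently swallowing the error,
            # which would leave final_tensor undefined at the return below.
            raise RuntimeError(
                f"failed to compute the value of node {node_id}") from e
        return final_tensor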

    def forward(self, x):
        self.node_value_map = self.init_node_value_map()

        assert len(x.shape) >= 2  #one for batchsize dim,one for element dimension
        # 1.input node value setting
        assert len(self.input_node_id_list) == x.shape[1]
        splited_input = torch.tensor_split(x, self.state_dim, dim=1)
        for each_node_id in self.node_value_map.keys():
            if each_node_id in self.input_node_id_list:
                self.node_value_map[each_node_id] = splited_input[each_node_id]

        # 2.get each output node value
        output_node_value_lst = [self.get_node_value(output_id) \
            for output_id in self.output_node_id_lst]
        assert len(output_node_value_lst) == self.action_dim

        # NOTE: after testing, a list of torch tensors still tracks gradients after being stacked
        output = torch.stack(output_node_value_lst, dim=0)
        assert output.shape[0] == self.action_dim
        assert output.shape[1] == x.shape[0]  # batch size
        assert output.shape[2] == 1

        # drop the trailing size-1 dim; the view shares memory with the stacked tensor
        output = output.view(output.shape[0], output.shape[1])
        output = output.permute(1, 0)  # TODO: maybe failed for multi dim action,transpose shape
        return output

class ActorCriticModel(nn.Module):
    def __str__(self):
        return (
            f"model input node num:{self.input_node_genes_num}\n"
            f"model output node num:{self.output_node_genes_num}\n"
            f"model hidden node num:{self.hidden_node_genes_num}\n"
            f"model connect num/enabled/disabled:{self.connection_genes_num}/"
            f"{self.enable_connection_genes_num}/"
            f"{self.disable_connection_genes_num}\n"
            f"model score:{self.score}")

    def __init__(self, state_dim, action_dim, action_lim=None):
        super(ActorCriticModel, self).__init__()
        self.state_dim = state_dim  
        self.action_dim = action_dim
        self.action_lim = action_lim
        self.is_discrete_action = action_lim is None
        self.log_prob = None  #TODO:check

        #todo:
        self.SCALE_ACTIVATION = 4.9
        ACTIVATION = 'sigmoid'
        self.activation = a.Activations().get(ACTIVATION)

        # ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ tail layer part ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓
        hid_size = self.action_dim * 3  #TODO:what is the best size?
        if self.is_discrete_action:
            self.distribution = torch.distributions.Categorical

            # new version
            # self.Q1_layer1 = nn.Linear(hid_size, self.action_dim)
            # self.Q1_layer2 = nn.Linear(hid_size, self.action_dim)

            # old version
            self.Q1_layer1 = nn.Linear(self.action_dim, hid_size)
            self.Q1_layer3 = nn.Linear(hid_size, self.action_dim)

            self.Q2_layer1 = nn.Linear(self.action_dim, hid_size)
            self.Q2_layer3 = nn.Linear(hid_size, self.action_dim)

            self.tail_layer_lst = [
                self.Q1_layer1, self.Q1_layer3, self.Q2_layer1, self.Q2_layer3
            ]
        else:  #continuous action
            self.distribution = torch.distributions.Normal
            # ---------------------- actor layer ----------------------
            self.mean_layer = nn.Linear(self.action_dim, self.action_dim)
            self.log_std_layer = nn.Linear(self.action_dim, self.action_dim)
            # ---------------------- critic layer ----------------------
            # for Q1
            self.Q1_layer1 = nn.Linear(1 + self.action_dim, hid_size)
            self.Q1_layer3 = nn.Linear(hid_size, 1)
            # for Q2
            self.Q2_layer1 = nn.Linear(1 + self.action_dim, hid_size)
            self.Q2_layer3 = nn.Linear(hid_size, 1)

            self.tail_layer_lst = [
                self.mean_layer, self.log_std_layer, self.Q1_layer1,
                self.Q1_layer3, self.Q2_layer1, self.Q2_layer3
            ]

        #  discrete or continuous
        self.forward_func = self.forward_discrete_version if self.is_discrete_action \
        else self.forward_continuous_version
        # ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑ tail layer part ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑

        #
        self.genome = None
        self.custom_Layer = None
        self.update_model_struct()

    def copy_model(self, genome=None, copy_tail_layer=True):
        device = next(
            self.parameters()).device  # get same device as of self model
        new_model = ActorCriticModel(self.state_dim, self.action_dim,
                                     self.action_lim).to(device)
        # # genome layer copy
        if genome is not None:
            new_model.update_model_struct(genome)
        else:
            new_model.update_model_struct(self.genome)
            # NOTE: parameters of the classic PyTorch layers (NOT the genome) need to be copied again here
            # TODO: check whether the parameters are changed after copying?
            # new_model.load_state_dict(copy.deepcopy(self.state_dict()))
            # j=9
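        if copy_tail_layer:
            assert len(self.tail_layer_lst) == len(new_model.tail_layer_lst)
            for i, each_layer in enumerate(self.tail_layer_lst):
                # Copy the weights into the registered layers; rebinding the list
                # entries would leave new_model.Q1_layer1 etc. untouched.
                new_model.tail_layer_lst[i].load_state_dict(
                    each_layer.state_dict())

        # MyBlock is rebuilt in update_model_struct, so move the model again
        return new_model.to(device)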

    @property
    def connection_genes_num(self):
        return len(self.genome.connection_genes)

    @property
    def enable_connection_genes_num(self):
        return len(
            self.genome.connection_genes) - self.disable_connection_genes_num

    @property
    def disable_connection_genes_num(self):
        cnt = 0
        for each_conn in self.genome.connection_genes:
            if not each_conn.is_enabled:
                cnt += 1
        return cnt

    @property
    def node_genes_num(self):
        return len(self.genome.node_genes)

    @property
    def input_node_genes_num(self):
        return len([
            each_node for each_node in self.genome.node_genes
            if each_node.type == 'input'
        ])

    @property
    def output_node_genes_num(self):
        return len([
            each_node for each_node in self.genome.node_genes
            if each_node.type == 'output'
        ])

    @property
    def hidden_node_genes_num(self):
        return self.node_genes_num - self.input_node_genes_num - self.output_node_genes_num

    @property
    def score(self):
        # return self.test_score
        return self.genome.fitness

    def set_score(self, new_score):
        self.genome.fitness = new_score
        # self.test_score = new_score

    def _initial_genome(self):
        new_genome = Genome()
        inputs = []
        outputs = []
        bias = None

        # Create nodes
        for _ in range(self.state_dim):
            n = new_genome.add_node_gene('input')
            inputs.append(n)

        for _ in range(self.action_dim):
            n = new_genome.add_node_gene('output')
            outputs.append(n)

        # NOTE: I don't like using bias
        # if self.Config.USE_BIAS:
        # bias = new_genome.add_node_gene('bias')

        # Create connections
        for input in inputs:
            for output in outputs:
                new_genome.add_connection_gene(input.id, output.id)

        if bias is not None:
            for output in outputs:
                new_genome.add_connection_gene(bias.id, output.id)

        return new_genome

    def update_model_struct(self, genome=None):
        """
        This method would be used by user, to update model struct.
        """

        if genome is None:
            self.genome = self._initial_genome()
        else:
            self.genome = copy.deepcopy(genome)

        # genome unfixed layer
        self.custom_Layer = MyBlock(self.genome, self.state_dim,
                                    self.action_dim)
        # print("********************")
        # for name,param in self.named_parameters():# this would print all parameters, even for those that not used in forward function
        #     print(f"{name}---->{param}\n")
        # j=9

    def update_genome_weight(self):
        """
        Copy the parameters stored in the ModuleDict of the block layer
        back to the matching genome.connection_genes.
        """
        found_cnt = 0
        for each_conn_gene in self.genome.connection_genes:
            if not each_conn_gene.is_enabled:
                continue
            # named_parameters() yields every registered parameter, even
            # those that are not used in forward
            for name, param in self.named_parameters():
                # expected name: custom_Layer.layer_map.<key>.weight
                split_lst = name.split(".")
                if (len(split_lst) == 4 and split_lst[1] == 'layer_map'
                        and split_lst[2] == each_conn_gene.name):
                    each_conn_gene.weight = param.clone().detach()
                    found_cnt += 1
                    break

        if self.enable_connection_genes_num != found_cnt:
            # debugging aid: dump all parameters when a gene was not matched
            for name, param in self.named_parameters():
                print(f"{name}---->{param}\n")
            print("================================================")
        # log("model param is updated back to genome","yellow")

    def forward(self, x, actions=None):
        """
        forward [summary]

        Args:
            x ([type]): [description]
            actions ([type], optional): [description]. Defaults to None.

        Returns:
            for discrete action version:action_tensor, action_probs, qf1_pi, qf2_pi, q1_score, q2_score

        """
        tmp_tensor = self.custom_Layer(x)
        output = self.forward_func(tmp_tensor, actions)
        return output

In general, a network set up this way is more flexible, because its STRUCTURE can CHANGE during training through MyBlock and the ModuleDict inside it. ModuleDict is essential here: it seems to be the only way to track the gradients of a custom flexible layer (whose structure may be far more irregular than that of ResNet, DenseNet, ...) when using torch. The network object also has a method, update_model_struct, that lets the user change its structure. So I guess the Sequential style you mentioned may not be feasible here.
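
To make the evaluation scheme easier to discuss for a possible port, here is a torch-free sketch of what MyBlock does: every enabled connection is kept in a map keyed by name (the role ModuleDict plays above), and a node's value is the memoized sum of its weighted incoming values. The `Conn` class, the `evaluate` function, and the sample genome are hypothetical names made up for this illustration. It uses plain floats instead of tensors, so nothing here tracks gradients, and it assumes the connection graph is acyclic.

```python
from dataclasses import dataclass


@dataclass
class Conn:
    """Hypothetical stand-in for a connection gene."""
    name: str
    in_node: int
    out_node: int
    weight: float
    enabled: bool = True


def evaluate(conns, inputs, output_ids):
    """Compute output node values by memoized recursion over enabled connections."""
    # Name-keyed registry of weights, analogous to layer_map in MyBlock.
    weight_map = {c.name: c.weight for c in conns if c.enabled}
    values = dict(inputs)  # node id -> value; input nodes are pre-filled

    def node_value(node_id):
        if node_id in values:  # already computed (or an input node)
            return values[node_id]
        incoming = [c for c in conns if c.enabled and c.out_node == node_id]
        if not incoming:
            raise ValueError(f"node {node_id} has no enabled incoming connection")
        # Sum the weighted values of all nodes that feed into this one.
        total = sum(weight_map[c.name] * node_value(c.in_node) for c in incoming)
        values[node_id] = total
        return total

    return [node_value(o) for o in output_ids]


# Two inputs (0, 1), one hidden node (3), one output (2).
conns = [
    Conn("0", 0, 2, 0.5),
    Conn("1", 1, 3, 2.0),
    Conn("2", 3, 2, -1.0),
    Conn("3", 1, 2, 4.0, enabled=False),  # disabled, so it is ignored
]
outputs = evaluate(conns, {0: 1.0, 1: 3.0}, [2])
print(outputs)
```

Changing the structure then only means adding, removing, or disabling entries in `conns` and rebuilding the map, which is what update_model_struct does by constructing a new MyBlock.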