facebookresearch / CodeGen

Reference implementation of code generation projects from Facebook AI Research. General toolkit to apply machine learning to code, from dataset creation to model training and evaluation. Comes with pretrained models.
MIT License

[Question] Retrained the original TransCoder model, but translation quality is poor #96

Open kaisawind opened 1 year ago

kaisawind commented 1 year ago

Hi,

We recently retrained the original TransCoder model (Python-Java only). The evaluation metrics look good, but the actual translations are poor.

With TransCoder_model_1 we can translate whole classes (for example, the Python class below into Java), but our retrained model cannot translate classes. It only translates standalone functions.

  1. Why can TransCoder_model_1 translate classes?
  2. Has our model not been trained for enough epochs?
  3. Does our model need more training data?
  4. Are we using the wrong datasets?

Input to translate (Python)

class Solution:
    def count_components(self, n: int, edges: List[List[int]]) -> int:
        graph = self.build_graph(n, edges)

        visited = set()
        num = 0
        for node in range(n):
            if node in visited:
                continue
            self.traversal_with_bfs(node, graph, visited)
            num += 1
        return num

    def build_graph(self, n, edges):
        graph = {node:[] for node in range(n)}
        for pre_node, node in edges:
            graph[pre_node].append(node)
            graph[node].append(pre_node)
        return graph

    def traversal_with_bfs(self, node, graph, visited):
        queue = collections.deque([node])
        visited.add(node)
        while queue:
            current_node = queue.popleft()
            for neighbor in graph[current_node]:
                if neighbor in visited:
                    continue
                queue.append(neighbor)
                visited.add(neighbor)

TransCoder_model_1

public static class Solution {
  int countComponents ( final int n , final List < List < Integer >> edges ) throws Exception {
    final List < List < Integer >> graph = buildGraph ( n , edges ) ;
    final Set < Integer > visited = new HashSet < Integer > ( ) ;
    int num = 0 ;
    for ( int i = 0 ;
    i < n ;
    i ++ ) {
      if ( i < visited . size ( ) ) {
        continue ;
      }
      traversalWithBfs ( i , graph , visited ) ;
      num ++ ;
    }
    return num ;
  }
  public void buildGraph ( final int n , final List < List < Integer >> edges ) throws Exception {
    final Map < Integer , List < Integer >> graph = new HashMap < Integer , List < Integer >> ( ) ;
    for ( int i = 0 ;
    i < n ;
    i ++ ) {
      graph . put ( i , new ArrayList < Integer > ( ) ) ;
      graph . get ( i ) . add ( i ) ;
    }
    traversalWithBfs ( n , graph , visited ) ;
  }
  public void traversalWithBfs ( final int node , final List < List < Integer >> graph , Set < Integer > visited ) throws Exception {
    final Queue < List < Integer >> queue = new LinkedList < List < Integer >> ( ) ;
    visited . add ( node ) ;
    while ( queue . size ( ) > 0 ) {
      final List < Integer > currentNode = queue . poll ( ) ;
      for ( final List < Integer > neighbor : graph . get ( currentNode ) ) {
        if ( neighbor . contains ( node ) ) {
          continue ;
        }
        queue . add ( neighbor ) ;
        visited . add ( neighbor ) ;
      }
    }
  }
}

TransCoder_my (our retrained model)

static void main ( ) {
  new Solution ( ) . countComponents ( ) ;
  for ( Node node : new LinkedList < > ( ) ) {
    if ( node . visited . add ( node ) ) {
      new Solution ( ) . queue ( ) ;
    }
  }
  new Solution ( ) . queue ( ) ;
  new Solution ( ) . queue ( ) ;
  new Solution ( ) . queue ( ) ;
  new Solution ( ) . queue ( ) ;
  new Solution ( ) . queue ( ) ;
}

Environment

2x RTX 3090 (24 GB each); data: Java (50 GB), Python (50 GB)

Pipeline:
Monolingual -> MLM
Monolingual Functions -> TransCoder (initialized from the pretrained MLM)

Changed parameters (see https://github.com/facebookresearch/CodeGen/issues/12#issuecomment-910566452):

--n_layers 6 
--emb_dim 1024 
--n_heads 8 

Results

Model / Task                Java -> Python          Python -> Java
Beam size                   k=1        k=10         k=1        k=10
TransCoder_model_1          46.87      48.81        33.89      35.55
TransCoder_model_2          46.87      47.73        32.64      35.97
TransCoder from DOBF        49.24      52.7         39.5       45.32
TransCoder_my (epoch 430)   73.267327  85.148515    58.585859  72.727273
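A quick arithmetic check on our own numbers (an observation, not an explanation). Every TransCoder_my score is an exact fraction: the denominator is 101 for Java -> Python (74/101 and 86/101) and 99 for Python -> Java (58/99 and 72/99). The sketch below searches for the smallest test-set size consistent with both beam sizes of each direction. It assumes the scores are percentages rounded to six decimals; the helper names are hypothetical. If that reading is right, our model was scored on roughly 100 examples per direction, which may not be the same test set behind the TransCoder_model_1/2 and DOBF numbers.

```python
# Scores for TransCoder_my (epoch 430), copied from the results table.
JAVA_TO_PYTHON = [73.267327, 85.148515]  # k=1, k=10
PYTHON_TO_JAVA = [58.585859, 72.727273]  # k=1, k=10


def matching_sizes(score_pct, max_n=1000, tol=5e-7):
    """Return every test-set size n <= max_n for which some integer count k
    gives 100 * k / n equal to score_pct up to 6-decimal rounding."""
    sizes = set()
    for n in range(1, max_n + 1):
        k = round(score_pct * n / 100)  # closest candidate number of passes
        if abs(100 * k / n - score_pct) < tol:
            sizes.add(n)
    return sizes


def smallest_common_size(scores):
    """Smallest test-set size consistent with all scores of one direction."""
    common = set.intersection(*(matching_sizes(s) for s in scores))
    return min(common) if common else None


print("Java -> Python:", smallest_common_size(JAVA_TO_PYTHON))
print("Python -> Java:", smallest_common_size(PYTHON_TO_JAVA))
```

This prints 101 and 99. A single score alone can match a smaller size (72.727273 is also 8/11), which is why the check intersects the k=1 and k=10 scores of each direction.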