guxd / deep-code-search

DeepCS: Deep Code Search
MIT License
279 stars 85 forks source link

the reasult not well #49

Open primary-studyer opened 4 years ago

primary-studyer commented 4 years ago

dcs->epcohStep=260000 top 10 ACC=0.767, MRR=0.32433587301587297, MAP=0.32433587301587297, nDCG=0.42961201846689157 top 5 ACC=0.6995, MRR=0.44343166666666667, MAP=0.44343166666666667, nDCG=0.5078651066901124 top 1 ACC=0.4761, MRR=0.4761, MAP=0.4761, nDCG=0.4761 数据集是codesearchnet中提供的Java数据,这是我训练过程达到的最优结果,poolsize设置的1000,达不到您之前说的结果要在0.9以上。 我将数据集按您所说划分为train和valid部分。感觉valid起的作用和test部分一样。 执行search的结果非常糟糕。我应该如何解决这个问题,使得search结果明显一些? 我之前用了您提供的epoch500来在大的codebase运行的时候,结果也是相关的比较少。我当时没找到原因,现在到我自己处理的时候,结果也这样,非常期待回复。

guxd commented 4 years ago

请问你是用的pytorch版吗? 你的数据集可能偏小,需要重新调参。 我提供epoch500的时候可能还没有用automl调参,后来测的pytorch能达到0.9以上。 另外poolsize设为10,000或100,000更合理。

primary-studyer commented 4 years ago

请问你是用的pytorch版吗? 你的数据集可能偏小,需要重新调参。 我提供epoch500的时候可能还没有用automl调参,后来测的pytorch能达到0.9以上。 另外poolsize设为10,000或100,000更合理。

这个数据集用于训练部分数据是23w左右,验证部分数据量1.5w左右。我想先在poolsize=1000达到0.9以后,再试试10000的。 关于调参 您有什么建议么。

        #parameters
        'name_len': 6,
        'api_len':30,
        'tokens_len':50,
        'desc_len': 30,
        'n_words': 10000, # len(vocabulary) + 1
        #vocabulary info
        'vocab_name':'vocab.name.json',
        'vocab_api':'vocab.apiseq.json',
        'vocab_tokens':'vocab.tokens.json',
        'vocab_desc':'vocab.desc.json',

    #training_params            
        'batch_size': 64,
        'chunk_size':200000,
        'nb_epoch': 15,
        #'optimizer': 'adam',
        'learning_rate':2.08e-4,
        'adam_epsilon':1e-8,
        'warmup_steps':5000,
        'fp16': False,
        'fp16_opt_level': 'O1', #For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3'].
                        #"See details at https://nvidia.github.io/apex/amp.html"

    # model_params
        'emb_size': 512,
        'n_hidden': 512,#number of hidden dimension of code/desc representation
        # recurrent
        'lstm_dims': 256, # * 2          
        'margin': 0.3986,
        'sim_measure':'cos',#similarity measure: cos, poly, sigmoid, euc, gesd, aesd. see https://arxiv.org/pdf/1508.01585.pdf
                     #cos, poly and sigmoid are fast with simple dot, while euc, gesd and aesd are slow with vector normalization.
}
return conf
guxd commented 4 years ago

调参需要专用平台。手工调参的话可以参考automl_config.yaml末尾关于参数的区间。 我再用我们的训练数据跑一下,给你一个训练好的模型,你用训练好的模型来测试你的数据看看效果。

primary-studyer commented 4 years ago

调参需要专用平台。手工调参的话可以参考automl_config.yaml末尾关于参数的区间。 我再用我们的训练数据跑一下,给你一个训练好的模型,你用训练好的模型来测试你的数据看看效果。

感谢。我的数据集小确实不太容易操作。

guxd commented 4 years ago

https://drive.google.com/file/d/15HoKv0efrVXNTsqCxoq2Swgh6ohuq5jI/view?usp=sharing 这里是训练好的一个模型,pool size选的10000, top-1精度结果如下 image

primary-studyer commented 4 years ago

https://drive.google.com/file/d/15HoKv0efrVXNTsqCxoq2Swgh6ohuq5jI/view?usp=sharing 这里是训练好的一个模型,pool size选的10000, image

因为我自己模型query不是很好,从头捋一遍的时候发现一个问题 例如convert inputstream to string 方法名为inputstreamToString,我分词为inputstream to string 存储到.name.h5中。 关于desc.h5我将存为inputstream还是input stream? 那么token.h5关于InputStream语句我存inputstream还是input stream? 因为有的是整体不用分割,有的词是需要分割的?这种情况我该怎么处理。因为我无法确定哪些词是部分拆分的。

guxd commented 4 years ago

我们简单的对代码里的token作了camel split, query没有拆分,你可以都试试。

primary-studyer commented 4 years ago

我们简单的对代码里的token作了camel split, query没有拆分,你可以都试试。 我在另一个数据集训练dcs的时候,acc top 10 poolsize-1000,打到了0.90+ 并且我query的时候将inputstream to string,转换成input stream to string来查询。其中top10最高相似度为0.90左右,但是结果并不相干。 当我把codebase等文件其他项删除 只留一个InputStreamToString代码片段时候,即len(codebase)==1,cos是0.93,但是在上面这个结果并没有出现,最高才0.9+。 0.93的那段并没有显示出来。

guxd commented 4 years ago

谢谢你提供的线索,可能代码还存在bug。如果找到原因麻烦您告知。

primary-studyer commented 4 years ago

谢谢你提供的线索,可能代码还存在bug。如果找到原因麻烦您告知。

原因貌似在repr_code.py中data_loader = torch.utils.data.DataLoader(dataset=use_set, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=1)的batch_size,您的代码默认是10000.我执行query语句input stream to `string,top5结果如下: (这个是我把codebase length的大小做成为5的代码库,就5个数据,当codebase是全部代码片段时,第一的这个0.9439719这条相似度变成0.72左右)

`('public static String inputStreamToString(InputStream is) throws MPException { String value = ""; if (is != null) {
try { ByteArrayOutputStream result = new ByteArrayOutputStream(); byte[] buffer = new byte[1024]; int length;
while ((length = is.read(buffer)) != -1) { result.write(buffer, 0, length); } value = result.toString("UTF-8");
} catch (Exception ex) {throw new com.mercadopago.exceptions.MPException(ex); } } return value; }\r\n', 0.9439719)

('private static String inputStreamToString(InputStream in) { try { BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(in));
StringBuilder stringBuilder = new StringBuilder(); String line = null; while ((line = bufferedReader.readLine()) != null) {
stringBuilder.append(line + "\n"); } bufferedReader.close(); return stringBuilder.toString(); }
catch (IOException e) { throw new RuntimeException("Failed to parse input stream", e); } }\r\n', 0.66112924)

('public static String inputStreamToString(InputStream in) { StringBuffer buffer = new StringBuffer(); try { BufferedReader br = new BufferedReader(new InputS treamReader(in, "UTF-8"), 1024); String line; while ((line = br.readLine()) != null) {
buffer.append(line); } } catch (IOException iox) { LOGR.warning(iox.getMessage()); }
return buffer.toString(); }\r\n', 0.6445415)

('public static String inputStreamToString(InputStream in) { Scanner scanner = new Scanner(in, "UTF-8"); String content = scanner.useDelimiter("\\A").next(); scanner.close(); return conte nt; }', 0.49575543) ('public static String readInputStreamToString( InputStream inputStream ) {try {List bytesList = new ArrayList(); byte b = 0; while( (b = (byte) inputStream.read()) != -1 ) {bytesList.add(b); } inputStream.close(); byte[] bArray = new byte[bytesList.size()]; for( int i = 0; i < bArray.length; i++ ) {bArray[i] = bytesList.get(i);} String file = new String(bArray); return file; } catch (IOException e) { e.printStackTrace();
return null; } }\r\n', 0.48657355)`

当我把batchzise设置成1. top10结果如下 ('public static int cuStreamWriteValue32 ( CUstream stream, CUdeviceptr addr, int value, int flags ) { return checkResult ( cuStreamW riteValue32Native ( stream, addr, value, flags ) ) ; } \n', 0.9701143) ('@Override protected Response encode ( final SamlRegisteredService service, final Response samlResponse, final HttpServletResponse h ttpResponse, final HttpServletRequest httpRequest, final SamlRegisteredServiceServiceProviderMetadataFacade adaptor, final String relayState, final String binding, final RequestAbstractType authnRequest, final Object assertion ) throws SamlException { LOGGER . trace ( " " , binding, adaptor . getEntityId ( ) ) ; if ( binding . equalsIgnoreCase ( SAMLConstants . SAML2_ARTIFACT_BINDING_URI ) ) { val encoder = new SamlResponseArtifactEncoder ( getSamlResponseBuilderConfigurationContext ( ) . getVelocityEngineFactory ( ) , adaptor, httpRequest, httpResponse, authnRequest, getSamlResponseBuilderConfigurationContext ( ) . getTicketRegistry ( ) , getSamlResponseBuilderConfigurationContext ( ) . getSamlArtifactTicketFactory ( ) , getSamlResponseBuilderConfigurationContext ( ) . getTicketGrantingTicketCookieGenerator ( ) , getSamlResponseBuilderConfigurationContext ( ) . getSamlArtifactMap ( ) ) ; return encoder . encode ( authnRequest, samlResponse, relayState ) ; } if ( binding . equalsIgnoreCase ( SAMLConstants . SAML2_POST_SIMPLE_SIGN_BINDING_URI ) ) { val encoder = new SamlResponsePostSimpleSignEncoder ( getSamlResponseBuilderConfigurationContext ( ) . getVelocityEngineFactory ( ) , adaptor, httpResponse, httpRequest ) ; return encoder . encode ( authnRequest, samlResponse, relayState ) ; } val encoder = new SamlResponsePostEncoder ( getSamlResponseBuilderConfigurationContext ( ) . getVelocityEngineFactory ( ) , adaptor, httpResponse, httpRequest ) ; return encoder . encode ( authnRequest, samlResponse, relayState ) ; } \n', 0.9691605) ('public static int cuStreamWriteValue64 ( CUstream stream, CUdeviceptr addr, long value, int flags ) { return checkResult ( cuStream WriteValue64Native ( stream, addr, value, flags ) ) ; } \n', 0.96685445) ('public static ByteBuf encode ( ByteBufAllocator allocator, int streamId, boolean fragmentFollows, boolean complete, boolean next, P ayload payload ) { return FLYWEIGHT . encode ( allocator, streamId, fragmentFollows, complete, next, 0, payload . hasMetadata ( ) ? payload . metadata ( ) . retain ( ) : null, payload . data ( ) . retain ( ) ) ; } \n', 0.9657972) ('public static < S > Stream < S > stream ( Iterable < S > input ) { return stream ( input, false ) ; } \n', 0.96461725)

('private byte [ ] decryptStream ( byte [ ] key, byte [ ] keepassFile ) throws IOException { CryptoInformation cryptoInformation = ne w CryptoInformation ( KeePassHeader . VERSION_SIGNATURE_LENGTH, keepassHeader . getMasterSeed ( ) , keepassHeader . getTransformSeed ( ) , keepassHeader . getEncryptionIV ( ) , keepassHeader . getTransformRounds ( ) , keepassHeader . getHeaderSize ( ) ) ; return decrypter . decryptDatabase ( key, cryptoInformation, keepassFile ) ; } \n', 0.96334887) ('private ByteBuf encodeReadHoldingRegisters ( ReadHoldingRegistersResponse response, ByteBuf buffer ) { buffer . writeByte ( respons e . getFunctionCode ( ) . getCode ( ) ) ; buffer . writeByte ( response . getRegisters ( ) . readableBytes ( ) ) ; buffer . writeBytes ( response . getRegisters ( ) ) ; return buffer ; } \n', 0.96252537) ('@Override public byte [ ] encode ( Endpoint endpoint ) { return ( endpoint . host ( ) + fieldDelimiter + endpoint . port ( ) + fiel dDelimiter + endpoint . weight ( ) ) . getBytes ( StandardCharsets . UTF_8 ) ; } \n', 0.96244895) ('private static InputStreamWithMetadata compressStreamWithGZIP ( InputStream inputStream ) throws SnowflakeSQLException { FileBacked OutputStream tempStream = new FileBackedOutputStream ( MAX_BUFFER_SIZE, true ) ; try { DigestOutputStream digestStream = new DigestOutputStream ( tempStream, MessageDigest . getInstance ( " " ) ) ; CountingOutputStream countingStream = new CountingOutputStream ( digestStream ) ; GZIPOutputStream gzipStream ; gzipStream = new GZIPOutputStream ( countingStream, true ) ; IOUtils . copy ( inputStream, gzipStream ) ; inputStream . close ( ) ; gzipStream . finish ( ) ; gzipStream . flush ( ) ; countingStream . flush ( ) ; return new InputStreamWithMetadata ( countingStream . getCount ( ) , Base64 . encodeAsString ( digestStream . getMessageDigest ( ) . digest ( ) ) , tempStream ) ; } catch ( IOException | NoSuchAlgorithmException ex ) { logger . error ( " " , ex ) ; throw new SnowflakeSQLException ( ex, SqlState . INTERNAL_ERROR, ErrorCode . INTERNAL_ERROR . getMessageCode ( ) , " " ) ; } } \n', 0.96195817) ('public static < U > Stream < U > stream ( final Spliterator < U > it ) { return StreamSupport . stream ( it, false ) ; } \n', 0.961 52276)

batchzise=1的结果更相关,我不知道这是为什么。我不太熟悉这个批操作的影响来自哪里。如果你找到原因请希望您的回复。

guxd commented 4 years ago

见repr_code.py第50行。根据现在的设置,chunk_size=2000000, 你的codebase需要至少达到这么多代码系统才存储你的code vector. 我怀疑你搜索的时候用的老的code vector(请确认). 解决办法:在codebase大小小于chunk_size时只存储一个code vector文件,相应的在search.py里load vector时也要做相应更改。

我这边暂时先不改,以防引起其他问题,后面时间充裕会完整修改调试。

primary-studyer commented 4 years ago

我有注意到您说的这块。 我的codebase大小是1569525,chunksize是2000000,如果nprocessed未达到chunksize会在第56行存储code vector. 我也常尝试将chunksize设置成20w,生成8个codevecs向量文件。 结果和chunksize是200w一样。 我确认我在尝试新的reprcode.py时,将之前的的向量文件删除。同时也在search.py进行相应修改。 我也尝试过在reprcode.py和seaech.py把chunksize去掉,都一次性处理。结果和上面也是一样的。 然后做了codebase 大小为1(仅包含inputStreamToString代码片段)。发现这个问题。 修改usedata的加载批次batchsize,出现这样结果。这是我处理时候的过程。感谢您之前的回复,对我帮助很大。

---Original--- From: "Xiaodong Gu"<notifications@github.com> Date: Sat, Jul 18, 2020 10:04 AM To: "guxd/deep-code-search"<deep-code-search@noreply.github.com>; Cc: "Author"<author@noreply.github.com>;"Direction"<825880441@qq.com>; Subject: Re: [guxd/deep-code-search] the reasult not well (#49)

见repr_code.py第50行。根据现在的设置,chunk_size=2000000, 你的codebase需要至少达到这么多代码系统才存储你的code vector. 我怀疑你搜索的时候用的老的code vector(请确认). 解决办法:在codebase大小小于chunk_size时只存储一个code vector文件,相应的在search.py里load vector时也要做相应更改。

我这边暂时先不改,以防引起其他问题,后面时间充裕会完整修改调试。

— You are receiving this because you authored the thread. Reply to this email directly, view it on GitHub, or unsubscribe.

guxd commented 3 years ago

已解决PyTorch版的Bug, 问题在modules.py文件里的h_n = h_n.transpose(1, 0).contiguous(). 去掉这行validation 效果提升,所以当时不下心注释掉了,把这行添上就行了。 现在测试效果没问题了。