sugarme / gotch

Go binding for PyTorch C++ API (libtorch)
Apache License 2.0

Float64Values() shows an error 'Unsupported Go type: []float64' #97

Closed: luxiant closed this issue 1 year ago

luxiant commented 1 year ago
// bertSentimentProcess takes a single-row gota DataFrame, classifies the sentence in its "text" column, and returns a sentimentRow.
func bertSentimentProcess(dataframe dataframe.DataFrame) sentimentRow {
    var torchResult *ts.Tensor
    ts.NoGrad(func() {
        torchResult, _, _ = useModels.bertModel.ForwardT(
            processSentenceIntoInput(dataframe.Col("text").Records()[0]),
            ts.None,
            ts.None,
            ts.None,
            ts.None,
            false,
        )
    })
    categoryProb := torchResult.MustSoftmax(-1, gotch.Double, true).Float64Values()
    var sentiment string
    switch {
    case categoryProb[0] > categoryProb[1] && categoryProb[0] > categoryProb[2]:
        sentiment = "long"
    case categoryProb[1] > categoryProb[0] && categoryProb[1] > categoryProb[2]:
        sentiment = "neutral"
    default:
        sentiment = "short"
    }
    return sentimentRow{
        post_num:  dataframe.Col("post_num").Records()[0],
        time:      dataframe.Col("time").Records()[0],
        text:      dataframe.Col("text").Records()[0],
        long:      categoryProb[0],
        neutral:   categoryProb[1],
        short:     categoryProb[2],
        sentiment: sentiment,
    }
}

// nonTextChars matches every character except Hangul, Latin letters, digits, and - % ? .
// MustCompile panics on an invalid pattern instead of silently discarding the error.
var nonTextChars = regexp.MustCompile("[^가-힣ㄱ-ㅎㅏ-ㅣa-zA-Z0-9\\-\\%\\?\\.]")

// processSentenceIntoInput cleans a sentence and converts it into an input tensor of maxLength token IDs.
func processSentenceIntoInput(sentence string) *ts.Tensor {
    sentence = strings.ReplaceAll(sentence, "- dc official App", " ")
    sentence = strings.ReplaceAll(sentence, "ㅋ", " ")
    sentence = strings.ReplaceAll(sentence, "\n", " ")
    sentence = strings.ReplaceAll(sentence, "ㅡ", " ")
    sentence = nonTextChars.ReplaceAllString(sentence, " ")
    // strings.Fields splits on runs of whitespace and drops empty strings,
    // so this collapses repeated spaces in one step.
    sentence = strings.Join(strings.Fields(sentence), " ")
    finalEncode, _ := useModels.tokenizer.Encode(
        tokenizer.NewSingleEncodeInput(
            tokenizer.NewInputSequence(sentence),
        ),
        true,
    )
    switch {
    case finalEncode.Len() > maxLength:
        finalEncode, _ = finalEncode.Truncate(maxLength, 2)
    case finalEncode.Len() < maxLength:
        finalEncode = &tokenizer.PadEncodings(
            []tokenizer.Encoding{*finalEncode},
            *paddingParameter,
        )[0]
    default:
    }
    var tokInput = make([]int64, maxLength)
    for i := 0; i < len(finalEncode.Ids); i++ {
        tokInput[i] = int64(finalEncode.Ids[i])
    }
    return ts.MustStack(
        []ts.Tensor{*ts.TensorFrom(tokInput)},
        0,
    ).MustTo(device, true)
}

I'm working on my project and hit this error while debugging.

```
root@codespaces-e73b16:/workspaces/KoBERT# go run main.go
2023/03/05 07:00:02 INFO: CachedDir="/root/.cache/transformer"
Successfully loaded model
0% | | (0/100, 0 it/hr) [0s:0s]
2023/03/05 07:00:06 Unsupported Go type: []float64
exit status 1
```

I searched my whole code for variables of type []float64, and the following line inside the function bertSentimentProcess is the only place where one is created:

categoryProb := torchResult.MustSoftmax(-1, gotch.Double, true).Float64Values()

I'm trying to figure out why this happens but still can't work it out. []float64 is an ordinary Go type, so I didn't expect this error. I first suspected a problem with my Go installation, but even after reinstalling Go 1.19 I still get the same error.
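The switch in bertSentimentProcess is an argmax over the three class probabilities in categoryProb, and it indexes categoryProb[0..2] without checking how many values there are. A minimal stdlib-only sketch of the same step with an explicit length check, assuming the probabilities are already in a []float64 and the label names come from a map shaped like the config's Id2Label; `labelFromProbs` is a hypothetical helper, not part of gotch.

```go
package main

import (
	"errors"
	"fmt"
)

// labelFromProbs returns the label of the highest probability.
// id2label maps class index -> label name (hypothetical helper).
func labelFromProbs(probs []float64, id2label map[int64]string) (string, error) {
	if len(probs) == 0 {
		return "", errors.New("empty probability slice: model produced no classes")
	}
	if len(probs) != len(id2label) {
		return "", fmt.Errorf("got %d probabilities but %d labels", len(probs), len(id2label))
	}
	best := 0
	for i := 1; i < len(probs); i++ {
		// Strict ">" keeps the first index on ties.
		if probs[i] > probs[best] {
			best = i
		}
	}
	label, ok := id2label[int64(best)]
	if !ok {
		return "", fmt.Errorf("no label for class index %d", best)
	}
	return label, nil
}

func main() {
	id2label := map[int64]string{0: "long", 1: "neutral", 2: "short"}

	label, err := labelFromProbs([]float64{0.1, 0.7, 0.2}, id2label)
	fmt.Println(label, err) // neutral <nil>

	_, err = labelFromProbs([]float64{}, id2label)
	fmt.Println(err)
}
```

One behavioral difference to be aware of: on a tie such as [0.4, 0.4, 0.2] this returns the first tied label ("long"), while the switch in the issue falls through to its default case and returns "short".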

sugarme commented 1 year ago

@luxiant,

Can you try printing out the tensors to check? Something like this:

fmt.Printf("torchResult: %i\n", torchResult)
sm := torchResult.MustSoftmax(-1, gotch.Double, true)
fmt.Printf("softmax tensor: %i\n", sm)

categoryProb := sm.Float64Values() 
fmt.Printf("categoryProb: %v\n", categoryProb) // if error occurs at above line, won't see this log.

There are some memory leaks in your code, but we can discuss those later.

Please report the logs. Thanks.
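In the same spirit as the debugging suggestion above, the classifier output can be checked before any values are read from it. A minimal stdlib-only sketch, assuming the logits' shape is available as a []int64 (gotch tensors expose their shape, for example through `MustSize()`; verify the exact method against your gotch version) and that a sequence-classification head returns shape [batch, numLabels]. `checkLogitsShape` is a hypothetical helper.

```go
package main

import (
	"errors"
	"fmt"
)

// checkLogitsShape verifies that classifier output has shape
// [batch, numLabels] with a non-empty batch, so later indexing such as
// probs[0..numLabels-1] cannot go out of range.
func checkLogitsShape(shape []int64, numLabels int) error {
	if numLabels < 1 {
		return errors.New("numLabels must be positive")
	}
	if len(shape) != 2 {
		return fmt.Errorf("expected 2-D logits [batch, labels], got shape %v", shape)
	}
	if shape[0] < 1 {
		return fmt.Errorf("empty batch in logits shape %v", shape)
	}
	if shape[1] != int64(numLabels) {
		return fmt.Errorf("logits have %d columns, expected %d (shape %v)", shape[1], numLabels, shape)
	}
	return nil
}

func main() {
	fmt.Println(checkLogitsShape([]int64{1, 3}, 3)) // <nil>
	fmt.Println(checkLogitsShape([]int64{1, 0}, 3))
}
```

Failing fast with a message that includes the actual shape points directly at a zero-width or wrongly sized output, rather than at whatever later call first trips over it.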

luxiant commented 1 year ago

Uh... I feel bad for bothering you with this. After a day and a half of troubleshooting, I finally found the cause, and it was a silly mistake on my part.

The problem was not in the code I pasted but in the code that loads my pretrained BERT model. The correct way to load the BERT config file, following the example you provided, is:

bertConfig, _ := bert.ConfigFromFile("model/bert_config.json") // load bert config file
dummyLabelMap := make(map[int64]string)
dummyLabelMap[0] = "long"
dummyLabelMap[1] = "neutral"
dummyLabelMap[2] = "short"
bertConfig.Id2Label = dummyLabelMap
bertConfig.OutputAttentions = true
bertConfig.OutputHiddenStates = true

But I had accidentally deleted part of it and was running this instead:

bertConfig, _ := bert.ConfigFromFile("model/bert_config.json") // load bert config file
dummyLabelMap := make(map[int64]string)
bertConfig.OutputAttentions = true
bertConfig.OutputHiddenStates = true

After I corrected this, the problem was solved.
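A plausible reading of why the missing labels mattered, stated as an assumption rather than something confirmed in this thread: if the sequence-classification head sizes its output layer by `len(Id2Label)` (as rust-bert's `BertForSequenceClassification` does), then an empty map produces a classifier with zero outputs, and there are no probabilities to read back. A minimal load-time check is sketched below; `bertLikeConfig` and `validateLabels` are hypothetical stand-ins, not the library's own types.

```go
package main

import (
	"errors"
	"fmt"
)

// bertLikeConfig is a hypothetical stand-in holding only the field that
// matters here; a real BERT config has many more fields.
type bertLikeConfig struct {
	Id2Label map[int64]string
}

// validateLabels checks that Id2Label is non-empty and that its keys are
// exactly 0..n-1, so an output layer of width len(Id2Label) has a label
// for every class index.
func validateLabels(cfg bertLikeConfig) error {
	n := len(cfg.Id2Label)
	if n == 0 {
		return errors.New("empty Id2Label: classifier would have no outputs")
	}
	for i := int64(0); i < int64(n); i++ {
		if _, ok := cfg.Id2Label[i]; !ok {
			return fmt.Errorf("Id2Label is missing class index %d", i)
		}
	}
	return nil
}

func main() {
	good := bertLikeConfig{Id2Label: map[int64]string{0: "long", 1: "neutral", 2: "short"}}
	broken := bertLikeConfig{Id2Label: map[int64]string{}}
	fmt.Println(validateLabels(good))   // <nil>
	fmt.Println(validateLabels(broken)) // empty Id2Label: classifier would have no outputs
}
```

Running a check like this right after the config is loaded would surface the mistake at startup, before any model weights or tensors are involved.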