dlunion / CC4.0

Caffe for CC4.0-Windows,简单的Caffe C++接口,方便简单
162 stars 69 forks source link

为什么使用vgg或者resnet训练的时候不收敛? #6

Open · how0723 opened this issue 6 years ago

how0723 commented 6 years ago

为什么使用vgg或者resnet训练的时候不收敛?已导入预训练模型

how0723 commented 6 years ago

是分类任务,用VGGFace2数据集做1000类分类,用的是VGGFace2的预训练模型

dlunion commented 6 years ago

你代码有bug吧

how0723 commented 6 years ago

@dlunion 只改了输入层,代码如下

class ResnetDataLayer : public DataLayer { public: SETUP_LAYERFUNC(ResnetDataLayer);

virtual int getBatchCacheSize() {
    return 3;
}

virtual void loadBatch(Blob** top, int numTop) {

    Blob* image = top[0];
    Blob* label = top[1];

    float* image_ptr = image->mutable_cpu_data();
    float* label_ptr = label->mutable_cpu_data();
    int batch_size = image->num();
    int w = image->width();
    int h = image->height();

    for (int i = 0; i < batch_size; ++i) {
        auto& item = this->vfs_[this->cursor_++];
        Mat im = imread(item.first);
        int r = im.cols<im.rows ? im.cols : im.rows;
        im = im(cv::Rect(im.cols / 2 - r / 2, im.rows / 2 - r / 2, r, r));
        if (!im.empty()) {
            if (im.size() != Size(w, h)) resize(im, im, Size(w, h));

            //im.convertTo(im, CV_32F, 1 / 127.5, -1);
            im.convertTo(im, CV_32F, 1.0, 0);

            Mat ms[3];
            float* check = image_ptr;
            for (int c = 0; c < 3; ++c) {
                ms[c] = Mat(h, w, CV_32F, image_ptr);
                image_ptr += w * h;
            }

            split(im, ms);
            CV_Assert((float*)ms[0].data == check);

            //memcpy(label_ptr, &item.second, sizeof(float));
            *label_ptr = item.second;
            label_ptr++;
        }
        else {
            i--;
        }

        if (this->cursor_ == this->vfs_.size()) {
            this->cursor_ = 0;
            std::random_shuffle(vfs_.begin(), vfs_.end());
        }
    }
}

void preperData() {
    string fileList;
    if (this->phase_ == PhaseTest)
        fileList = "./to_test.txt";
    else
        fileList = "./to_train.txt";

    fstream input(fileList, ios::in);
    string line;
    while (std::getline(input, line))
    {
        int n = line.find("\t");
        std::string filename = line.substr(0, n);
        int label = atoi((const char*)line.substr(n + 1).c_str());
        vfs_.push_back(make_pair(filename, label));
    }
    input.close();

    CV_Assert(vfs_.size() > 0);
    this->cursor_ = 0;
    std::random_shuffle(vfs_.begin(), vfs_.end());
}

virtual void setup(const char* name, const char* type, const char* param_str, int phase, Blob** bottom, int numBottom, Blob** top, int numTop) {
    map<string, string> param = parseParamStr(param_str);
    //this->num_ = getParamInt(param, "num", 6);
    this->num_ = 1;
    const int batch_size = getParamInt(param, "batch_size");
    this->phase_ = phase;

    top[0]->Reshape(batch_size, 3, getParamInt(param, "height"), getParamInt(param, "width"));
    top[1]->Reshape(batch_size, this->num_, 1, 1);
    preperData();

    __super::setup(name, type, param_str, phase, bottom, numBottom, top, numTop);
}

private: vector<pair<string, float >> vfs; int cursor; int phase; int num; };