[Open] Issue opened by how0723 6 years ago
这是一个分类任务：使用 VGGFace2 数据集做 1000 类分类，加载的是 VGGFace2 的预训练模型。
你的代码有 bug 吧？
@dlunion 只改了输入层，代码如下：
class ResnetDataLayer : public DataLayer { public: SETUP_LAYERFUNC(ResnetDataLayer);
virtual int getBatchCacheSize() {
return 3;
}
virtual void loadBatch(Blob** top, int numTop) {
Blob* image = top[0];
Blob* label = top[1];
float* image_ptr = image->mutable_cpu_data();
float* label_ptr = label->mutable_cpu_data();
int batch_size = image->num();
int w = image->width();
int h = image->height();
for (int i = 0; i < batch_size; ++i) {
auto& item = this->vfs_[this->cursor_++];
Mat im = imread(item.first);
int r = im.cols<im.rows ? im.cols : im.rows;
im = im(cv::Rect(im.cols / 2 - r / 2, im.rows / 2 - r / 2, r, r));
if (!im.empty()) {
if (im.size() != Size(w, h)) resize(im, im, Size(w, h));
//im.convertTo(im, CV_32F, 1 / 127.5, -1);
im.convertTo(im, CV_32F, 1.0, 0);
Mat ms[3];
float* check = image_ptr;
for (int c = 0; c < 3; ++c) {
ms[c] = Mat(h, w, CV_32F, image_ptr);
image_ptr += w * h;
}
split(im, ms);
CV_Assert((float*)ms[0].data == check);
//memcpy(label_ptr, &item.second, sizeof(float));
*label_ptr = item.second;
label_ptr++;
}
else {
i--;
}
if (this->cursor_ == this->vfs_.size()) {
this->cursor_ = 0;
std::random_shuffle(vfs_.begin(), vfs_.end());
}
}
}
void preperData() {
string fileList;
if (this->phase_ == PhaseTest)
fileList = "./to_test.txt";
else
fileList = "./to_train.txt";
fstream input(fileList, ios::in);
string line;
while (std::getline(input, line))
{
int n = line.find("\t");
std::string filename = line.substr(0, n);
int label = atoi((const char*)line.substr(n + 1).c_str());
vfs_.push_back(make_pair(filename, label));
}
input.close();
CV_Assert(vfs_.size() > 0);
this->cursor_ = 0;
std::random_shuffle(vfs_.begin(), vfs_.end());
}
virtual void setup(const char* name, const char* type, const char* param_str, int phase, Blob** bottom, int numBottom, Blob** top, int numTop) {
map<string, string> param = parseParamStr(param_str);
//this->num_ = getParamInt(param, "num", 6);
this->num_ = 1;
const int batch_size = getParamInt(param, "batch_size");
this->phase_ = phase;
top[0]->Reshape(batch_size, 3, getParamInt(param, "height"), getParamInt(param, "width"));
top[1]->Reshape(batch_size, this->num_, 1, 1);
preperData();
__super::setup(name, type, param_str, phase, bottom, numBottom, top, numTop);
}
private: vector<pair<string, float >> vfs; int cursor; int phase; int num; };
为什么使用 VGG 或者 ResNet 训练时不收敛？已经导入了预训练模型。