Closed llxxxll closed 5 years ago
Reproduced.
After going over the code, this is more likely a feature request: to automatically support VARCHAR or TEXT fields.
In text classification problems, model inputs like news title, news content, and news keywords are sequences of text, so we need to support typical models like:
1. sequence_pooling, to pool a sequence into a dense feature.
2. CNN text classification, which uses operators like sequence_conv.
Things in tensorflow supporting sequence seems not stable yet, like: https://www.tensorflow.org/api_docs/python/tf/contrib/feature_column/sequence_input_layer and https://www.tensorflow.org/api_docs/python/tf/contrib/estimator/RNNEstimator. Yet we can still apply these "contrib" features to simply make SQLFlow support general sequence of text.
For the most simple case, we only implement No.1 for now.
Preprocessing
CREATE TABLE `train` ( `id` bigint(20) NOT NULL, `class_id` int(3) NOT NULL, `class_name` varchar(100) NOT NULL, `news_title` varchar(255) NOT NULL, `news_keywords` varchar(255) NOT NULL) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;COMMIT;
/var/lib/mysql-files
)
load data local infile '/var/lib/mysql-files/toutiao_cat_data.txt' into table train CHARACTER SET utf8 fields terminated by '_!_' lines terminated by "\n";
vocab
and write integers to MySQL:
# -*- coding: utf-8 -*-
# encoding=utf-8
import sys
import codecs
import mysql.connector
import jieba
stdout = codecs.getwriter('utf-8')(sys.stdout)
def prepare_vocab():
    """Build a word -> integer-id vocabulary from the news titles in the
    `train` table and dump it to vocab.txt (one "word<TAB>id" line each).

    Id 0 is reserved for the ``<unk>`` token; every word first seen gets
    the next id.  Also tracks and prints the longest tokenized title so
    the caller knows what sequence length to pad to.
    """
    mydb = mysql.connector.connect(
        host="127.0.0.1",
        user="root",
        passwd="root",
        database="toutiao",
        use_unicode=True,
        charset="utf8",
    )
    try:
        mycursor = mydb.cursor()
        mycursor.execute("SELECT * FROM train")
        voc = {"<unk>": 0}  # id 0 reserved for out-of-vocabulary words
        wordid = 0
        max_seq_length = 0
        while True:
            row = mycursor.fetchone()
            if row is None:
                break
            # row[3] is the news_title column; segment it into words.
            seg_len = 0
            for w in jieba.cut(row[3]):
                if w not in voc:  # idiomatic `not in` (was `not w in voc`)
                    wordid += 1
                    voc[w] = wordid
                # Count every token (not only unseen ones), otherwise
                # max_seq_length would undercount and padding would be wrong.
                seg_len += 1
            if seg_len > max_seq_length:
                max_seq_length = seg_len
        mycursor.reset()
    finally:
        # Fix: the connection was never closed before (leak on any error).
        mydb.close()
    print("total vocab size: ", len(voc), " max seq length: ", max_seq_length)
    with open("vocab.txt", "w", encoding="utf-8") as fn:
        for w in voc:
            fn.write("%s\t%d\n" % (w, voc[w]))
def load_vocab(file_path):
    """Load a vocab file produced by prepare_vocab().

    Returns a dict mapping word -> id.  Ids are kept as strings because
    the caller joins them straight into a comma-separated feature string.
    Malformed lines (no single tab separator) are reported and skipped.
    """
    voc = dict()
    # The file was written as UTF-8; read it back explicitly instead of
    # relying on the platform default encoding (breaks on Windows/GBK).
    with open(file_path, "r", encoding="utf-8") as fn:
        for line in fn:
            # rstrip("\n") instead of l[0:-1]: a final line without a
            # trailing newline would otherwise lose its last character.
            stripped = line.rstrip("\n")
            try:
                w, idx = stripped.split("\t")
                voc[w] = idx
            except ValueError:
                # Narrowed from a bare `except:` — only a bad split is
                # expected here; anything else should propagate.
                print("skip vocab: ", line)
    return voc
def write_prepared_table(max_seq_length):
    """Encode every row of `train` and insert it into `train_processed`.

    news_title is tokenized with jieba, each token mapped to its vocab id
    (unknown words map to 0), and the id list is zero-padded to
    max_seq_length, producing strings like "9,100,33,21,0,0,0,0".

    Args:
        max_seq_length: pad length for the title id list, i.e. the longest
            tokenized title reported by prepare_vocab().
    """
    voc = load_vocab("vocab.txt")
    conn_write = mysql.connector.connect(
        host="127.0.0.1",
        user="root",
        passwd="root",
        database="toutiao",
        use_unicode=True,
        charset="utf8",
    )
    conn_read = mysql.connector.connect(
        host="127.0.0.1",
        user="root",
        passwd="root",
        database="toutiao",
        use_unicode=True,
        charset="utf8",
    )
    try:
        read_cursor = conn_read.cursor()
        read_cursor.execute("SELECT * FROM train")
        write_cursor = conn_write.cursor()
        # Parameterized query: the previous %-formatted "%d ... \"%s\""
        # string broke on titles/keywords containing quotes and was open
        # to SQL injection.
        insert_sql = (
            "INSERT INTO train_processed "
            "(id, class_id, class_name, news_title, news_keywords) "
            "VALUES (%s, %s, %s, %s, %s)"
        )
        while True:
            row = read_cursor.fetchone()
            if row is None:
                break
            # Map each title token to its vocab id; unknown words -> "0".
            title_tensor = [voc.get(w, "0") for w in jieba.cut(row[3])]
            # Pad with id 0 up to the fixed length (no-op when already long).
            title_tensor += ["0"] * (max_seq_length - len(title_tensor))
            write_cursor.execute(
                insert_sql,
                (row[0], row[1], row[2], ",".join(title_tensor), row[4]),
            )
        # Commit once after the loop instead of per row — same end state,
        # far fewer fsyncs.
        conn_write.commit()
        read_cursor.reset()
    finally:
        conn_read.close()
        # Fix: conn_write was never closed before.
        conn_write.close()
if __name__ == "__main__":
    # Pass 1: build vocab.txt and report the longest tokenized title.
    prepare_vocab()
    # Pass 2: encode and pad the rows; 92 is the max sequence length
    # printed by prepare_vocab() on this dataset.
    write_prepared_table(92)
codegen.go
to support reading preprocessed data [DOING]
Dataset: chinese_news_dataset
SQLFlow SQL:
Return logs:
Table info:
Simple data: