nyk510 / vivid

Support Tools for Machine Learning VIVIDLY
Other
40 stars 2 forks source link

[feature] Support SWEM #12

Open nyk510 opened 4 years ago

nyk510 commented 4 years ago

class TextSWEMAtom(AbstractMergeAtom):
    """Merge atom producing PCA-compressed SWEM embeddings of a text column.

    The text column of the outer dataframe is normalized, parsed and embedded
    with ``SWEM`` (backed by ``W2V.load_model()``); the embedding is then
    reduced to ``n_components`` dimensions with PCA. The embedding step is
    cached on disk with joblib because it is expensive to recompute.

    Args:
        n_components: Number of PCA components, i.e. output feature columns.
        agg: Aggregation name forwarded to ``SWEM(aggregation=...)``.
        text_column: Column of the outer dataframe holding the raw text.
    """

    # NOTE(review): 'hoge' looks like a placeholder merge key - confirm the
    # real key before relying on this atom.
    merge_key = 'hoge'

    def __init__(self, n_components=32, agg='max', text_column='body'):
        self.n_components = n_components
        self.agg = agg
        self.text_column = text_column
        # The cached embedding depends on the aggregation as well as on the
        # column, so both are part of the file name. Keying on the column
        # alone would make atoms with different `agg` share a stale cache.
        self.cache_path = os.path.join(
            CACHE_DIR, f'kiji_text_{text_column}_{agg}.joblib')

        super().__init__()

    def __str__(self):
        """Return the parent description extended with this atom's settings."""
        base = super().__str__()
        return base + f' {self.agg}@{self.text_column}_n_components={self.n_components}'

    def read_outer_dataframe(self):
        """Return the outer dataframe as provided by ``read_all()``."""
        return read_all()

    def fit(self, input_df: pd.DataFrame, y=None):
        """Record whether this call is a training context (``y`` is given).

        Returns:
            ``self``.
        """
        self.is_train_context = y is not None
        return self

    def load_parsed_docs(self):
        """Return the SWEM embedding of the text column, using the disk cache.

        Returns:
            A pair ``(x, idx_none)``: ``x`` is the output of
            ``swem.transform`` for the documents whose normalized text is not
            ``None``, and ``idx_none`` is a boolean array over all rows of the
            outer dataframe that is True where the normalized text is ``None``.
            When loaded from the cache the pair is a two-element list.
        """
        if os.path.exists(self.cache_path):
            return joblib.load(self.cache_path)

        text_data = self.df_outer[self.text_column]

        with timer(logger=logger, format_str=self.text_column + ' parse context {:.3f}[s]'):
            docs = [safe_normalize(d) for d in text_data]
            # Build the mask with identity checks. `array == None` relies on
            # NumPy's elementwise comparison against None, which is not
            # reliable when the array ends up with a string dtype.
            idx_none = np.array([d is None for d in docs], dtype=bool)
            valid_docs = [d for d in docs if d is not None]

            parser = DocumentParser()
            parsed = [parser.call(s) for s in valid_docs]

            swem = SWEM(W2V.load_model(), aggregation=self.agg)
            x = swem.transform(parsed)

        joblib.dump([x, idx_none], self.cache_path)
        return x, idx_none

    def generate_outer_feature(self):
        """Build the PCA-reduced SWEM feature frame for the outer dataframe.

        In a training context a new PCA is fitted and stored as
        ``self.clf_pca_``; otherwise the previously fitted PCA is reused.
        Rows whose normalized text is ``None`` are left as zeros.

        Returns:
            A DataFrame with one row per outer-dataframe row and
            ``n_components`` columns named
            ``swem_{agg}_{text_column}_{i}``.

        Raises:
            AttributeError: If no PCA has been fitted yet (``fit`` was not
                called with ``y``) when a transform is requested.
        """
        with timer(logger, format_str=self.text_column + ' load {:.3f}[s]'):
            x, idx_none = self.load_parsed_docs()

        if getattr(self, 'is_train_context', False):
            clf_pca = PCA(n_components=self.n_components)
            clf_pca.fit(x)
            self.clf_pca_ = clf_pca

        if not hasattr(self, 'clf_pca_'):
            # Same exception type the bare attribute access raised before,
            # but with an actionable message.
            raise AttributeError(
                'clf_pca_ is not fitted; call fit() with y (training context) '
                'before generating features.')

        transformed = self.clf_pca_.transform(x)
        retval = np.zeros(shape=(len(self.df_outer), self.n_components))
        retval[~idx_none] = transformed

        columns = [f'swem_{self.agg}_{self.text_column}_' + str(i)
                   for i in range(self.n_components)]
        return pd.DataFrame(retval, columns=columns)
nyk510 commented 4 years ago

大体において作成は滅茶コスト高いので cacheable にして cache dir に保存するようにしても良さそう