andreassot10 opened 4 years ago
Hey, thanks for the analysis. I agree that a lot of things are suboptimal here. The goal for this first draft of a vectorizer was to see whether a) people would use it at all and b) how / if we can include text mining stuff at all.

I agree that creating a sparse matrix / `dfm` instead of a `data.table` might be the better solution. You have to consider, though, that we are somewhat limited by `mlr3`'s backends. Do you think a Sparse Matrix Backend would solve this, i.e. by converting `dfm` -> sparse for the PipeOp and sparse -> `dfm` for training again? (If `dfm` is not just a wrapper around a sparse matrix anyway?)

If this is not an option, we would have to consider writing a `DataBackendDFM`, which would make things only a little more complicated. In this case, we should think about creating something like `mlr3nlp`, though.

Is this something you would like to help with? My time is currently limited, but I will gladly help by discussing / reviewing.
Thanks @pfistfl. I think that I could help with this, or at least I could try. I'm doing this stuff for work, which means that I cannot allocate all of my time to this, but a few hours a week over the next few weeks would be a reasonable allocation.
Cool! A first step would be to see whether we can leverage the existing SparseMatrix Backend. I think, if we simply benchmark how the conversion (`dfm` -> sparse for the PipeOp and sparse -> `dfm`) compares to using a `data.table`, we might be able to pinpoint this already.
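A quick sketch of that comparison (assuming `as.dfm()` round-trips both a sparse `Matrix` and the `data.frame` produced by `convert()`; `data_char_ukimmig2010` is just a small corpus shipped with quanteda, used here for illustration):

library(quanteda)
library(Matrix)
library(microbenchmark)

d <- dfm(tokens(data_char_ukimmig2010))

microbenchmark(
  via_sparse = as.dfm(as(d, "dgCMatrix")),            # dfm -> sparse -> dfm
  via_table = as.dfm(convert(d, to = "data.frame")),  # dfm -> data.frame -> dfm
  times = 10
)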
Sounds reasonable. I'll take a look this week.
@pfistfl, I need a little bit of support on this please, as I'm pretty new to the R6 stuff. So I'm trying to understand what the backends are and how they work. Can you let me know if these statements are correct:

1. `mlr3pipelines::PipeOpTextVectorizer` uses `mlr3::DataBackendDataTable`.
2. `mlr3::DataBackendDataTable` is what converts the data inside `mlr3pipelines::PipeOpTextVectorizer` into `data.table` format.
3. Where is `mlr3::DataBackendDataTable` called in `mlr3pipelines::PipeOpTextVectorizer`? Searching for 'backend' in `PipeOpTextVectorizer.R` doesn't return any relevant results.
4. To fix this, we would replace `mlr3::DataBackendDataTable` with `mlr3::DataBackendMatrix` inside `mlr3pipelines::PipeOpTextVectorizer`, right?

It'd be great if you could guide me a bit so that I can advance this as fast as possible.
Thanks
`PipeOpTextVectorizer` inherits from `PipeOpTaskPreproc`. This pipeop has the methods `.train`, `.train_task` and `.train_dt`. Each of those methods calls the next one and ensures that the output of the child classes conforms to a class; in the case of overwriting `_dt`, to a `data.table`. By switching from `_dt` to `_task` we would basically have to make sure to return a `Task` instead of a `data.table`, leaving us the choice of a backend.
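To make the pattern concrete, here is a minimal, hypothetical sketch (the class name and the dummy all-zero matrix are made up; a real implementation would fill the matrix from the tokenized text and store its vocabulary in the state):

library(mlr3)
library(mlr3pipelines)
library(Matrix)
library(R6)

# Hypothetical sketch: override .train_task instead of .train_dt, so the
# PipeOp can return a Task whose features live in a sparse Matrix backend.
PipeOpSparseDummy = R6Class("PipeOpSparseDummy",
  inherit = PipeOpTaskPreproc,
  public = list(
    initialize = function(id = "sparsedummy") {
      super$initialize(id = id)
    }
  ),
  private = list(
    .train_task = function(task) {
      self$state = list()  # a real PipeOp would store its vocabulary here
      private$.transform(task)
    },
    .predict_task = function(task) {
      private$.transform(task)
    },
    .transform = function(task) {
      fn = task$feature_names
      # stand-in for the vectorization result: a sparse all-zero matrix
      m = Matrix(0, nrow = task$nrow, ncol = 3, sparse = TRUE,
        dimnames = list(NULL, paste0("tok", 1:3)))
      # drop the original features, cbind the sparse data as a new backend
      task$select(setdiff(fn, fn))$cbind(as_data_backend(m))
    }
  )
)

p = PipeOpSparseDummy$new()
p$train(list(tsk("iris")))[[1]]$data()  # features now come from a Matrix backend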
As to your questions:

1. Yes, but this is not necessary.
2. No. `as.data.frame(map(colwise_results, "matrix"))` converts into a `data.frame` / `data.table`; `.train_dt` ensures this is a `data.table`.
3. This happens inside `PipeOpTaskPreproc` (`train_task`, which internally calls `train_dt`). AFAIK `cbind` converts to `DataBackendXXX` automatically, depending on what you cbind.
4. Replace `train_dt` with `train_task`, copy over a little code from `train_task`, and cbind a sparse matrix.

Thanks so much for the detailed response and ideas, @pfistfl.
I'm almost convinced that the only way forward would be to break completely free from the `data.table` / `data.frame` / `Matrix` formats and create a `DataBackendDFM` backend. The advantage would be immensely faster computation times in text classification problems. The tradeoff is that this backend would be exclusive to `quanteda.textmodels`, as this package and `quanteda` are the only ones that work with `dfm` objects (to the best of my knowledge). However, the learners in `quanteda.textmodels` are specialised text classification models that can naturally handle large sparse matrices, which is a great plus. And they're pretty fast.

I created a `mlr3` version of `quanteda.textmodels::textmodel_nb()` here. This new learner, `classif.textmodel_nb`, takes the `data.table` data from `PipeOpTextVectorizer` and converts it to a `dfm` before fitting the model.

What follows is a benchmark exercise showing how converting the data from `dfm` to `data.table` in `PipeOpTextVectorizer`, and then from `data.table` to `dfm` in `classif.textmodel_nb`, is grossly inefficient:
library(mlr3)
library(mlr3pipelines)
library(quanteda)
library(quanteda.textmodels)
library(tidymodels)
library(microbenchmark)
# Movie corpus data in 'corpus' format
corp_movies <- data_corpus_moviereviews
summary(corp_movies, 5)
class(corp_movies)
# Movie corpus data in 'data frame' format. Will be passed to mlr3's task function
corp_movies_df <- convert(corp_movies, to = 'data.frame') %>%
select(sentiment, text) %>%
rename(target_variable = sentiment)
task <- TaskClassif$new("task_id", corp_movies_df,
target = 'target_variable')
task$col_roles$stratum <- 'target_variable'
po_text <- po(
"textvectorizer",
param_vals = list(
stopwords_language = "en",
scheme_df = 'inverse',
remove_punct = TRUE,
remove_symbols = TRUE,
remove_numbers = TRUE
),
affect_columns = selector_name('text')
)
mnb <- lrn('classif.textmodel_nb', predict_type = 'prob')
learners <- list(mnb)
names(learners) <- sapply(learners, function(x) x$id)
# Our pipeline
graph <-
po_text %>>%
po("branch", names(learners)) %>>% # Branches will be as many as the learners (one in this example)
gunion(unname(learners)) %>>%
po("unbranch") # Gather results for individual learners into a results table
graph$plot() # Plot pipeline
pipe <- GraphLearner$new(graph) # Convert pipeline to learner
pipe$predict_type <- 'prob'
# Parameter set
ps_text <- ParamSet$new(list(
ParamInt$new('textvectorizer.n', lower = 2, upper = 3))
)
param_set <- ParamSetCollection$new(list(
ParamSet$new(list(pipe$param_set$params$branch.selection$clone())),
ps_text
))
# Set up tuning instance
instance <- TuningInstanceSingleCrit$new(
task = task,
learner = pipe,
resampling = rsmp('cv', folds = 2),
measure = msr('classif.bbrier'),
search_space = param_set,
terminator = trm("evals", n_evals = 5),
store_models = TRUE)
tuner <- TunerRandomSearch$new()
tuner$optimize(instance)
The settings above, with just 2 CV folds and 5 evaluations, almost fried a 16GB Mac. Note below how long each batch takes (batch 1: 5 minutes; batch 2: 13 minutes; batch 3: 6 minutes; batches 4-5: I killed the process):
INFO [10:15:31.030] Starting to optimize 2 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals>'
INFO [10:15:31.102] Evaluating 1 configuration(s)
INFO [10:15:31.647] Benchmark with 2 resampling iterations
INFO [10:15:31.649] Applying learner 'textvectorizer.branch.classif.textmodel_nb.unbranch' on task 'task_id' (iter 1/2)
INFO [10:17:52.365] Applying learner 'textvectorizer.branch.classif.textmodel_nb.unbranch' on task 'task_id' (iter 2/2)
INFO [10:20:34.627] Finished benchmark
INFO [10:20:34.884] Result of batch 1:
INFO [10:20:34.889] branch.selection textvectorizer.n classif.bbrier uhash
INFO [10:20:34.889] classif.textmodel_nb 2 0.1730278 3cb05f82-f57f-4040-b304-5127a2563deb
INFO [10:20:35.718] Evaluating 1 configuration(s)
INFO [10:20:36.370] Benchmark with 2 resampling iterations
INFO [10:20:36.372] Applying learner 'textvectorizer.branch.classif.textmodel_nb.unbranch' on task 'task_id' (iter 1/2)
INFO [10:26:49.635] Applying learner 'textvectorizer.branch.classif.textmodel_nb.unbranch' on task 'task_id' (iter 2/2)
INFO [10:33:48.536] Finished benchmark
INFO [10:33:49.863] Result of batch 2:
INFO [10:33:49.874] branch.selection textvectorizer.n classif.bbrier uhash
INFO [10:33:49.874] classif.textmodel_nb 3 0.2001218 73e5c904-7f27-4e8a-bfcf-157d90340f20
INFO [10:33:49.968] Evaluating 1 configuration(s)
INFO [10:33:52.044] Benchmark with 2 resampling iterations
INFO [10:33:52.047] Applying learner 'textvectorizer.branch.classif.textmodel_nb.unbranch' on task 'task_id' (iter 1/2)
INFO [10:36:44.746] Applying learner 'textvectorizer.branch.classif.textmodel_nb.unbranch' on task 'task_id' (iter 2/2)
INFO [10:39:36.997] Finished benchmark
INFO [10:39:39.229] Result of batch 3:
INFO [10:39:39.234] branch.selection textvectorizer.n classif.bbrier uhash
INFO [10:39:39.234] classif.textmodel_nb 2 0.1730278 6c281af7-a9a0-4764-a9b5-78dd3229efdc
INFO [10:39:39.277] Evaluating 1 configuration(s)
INFO [10:39:39.801] Benchmark with 2 resampling iterations
INFO [10:39:39.804] Applying learner 'textvectorizer.branch.classif.textmodel_nb.unbranch' on task 'task_id' (iter 1/2)
So let's try something else: train the text vectorizer on the whole task and then pass the new dataset to the learner (the text feature will now be a document-feature matrix in `data.table` format):
microbenchmark(
{
dfm_data_table <- po_text$train(list(task))[[1]]$data()
task1 <- TaskClassif$new("task_id", dfm_data_table, target = 'target_variable')
mnb_mlr3 <- mnb$train(task1)
mnb_mlr3$predict(task1)
},
{
dfm_corpus <- dfm(corp_movies)
mnb_quanteda <- textmodel_nb(y = dfm_corpus@docvars$sentiment, x = dfm_corpus)
predict(mnb_quanteda, dfm_corpus)
},
times = 1
)
# Result in seconds:
# min lq mean median uq max neval
# 25.587085 25.587085 25.587085 25.587085 25.587085 25.587085 1
# 1.691156 1.691156 1.691156 1.691156 1.691156 1.691156 1
Let's also break down the above process to find the most time-consuming sub-processes:
microbenchmark(
{
dfm_data_table <- po_text$train(list(task))[[1]]$data()
task1 <- TaskClassif$new("task_id", dfm_data_table, target = 'target_variable')
},
mnb_mlr3 <- mnb$train(task1),
mnb_mlr3$predict(task1),
dfm_corpus <- dfm(corp_movies),
mnb_quanteda <- textmodel_nb(y = dfm_corpus@docvars$sentiment, x = dfm_corpus),
predict(mnb_quanteda, dfm_corpus),
times = 1
)
# Results in milliseconds:
# min lq mean median uq max neval
# 18631.169278 18631.169278 18631.169278 18631.169278 18631.169278 18631.169278 1
# 4643.412090 4643.412090 4643.412090 4643.412090 4643.412090 4643.412090 1
# 1718.631021 1718.631021 1718.631021 1718.631021 1718.631021 1718.631021 1
# 1598.421712 1598.421712 1598.421712 1598.421712 1598.421712 1598.421712 1
# 38.023609 38.023609 38.023609 38.023609 38.023609 38.023609 1
# 7.793736 7.793736 7.793736 7.793736 7.793736 7.793736 1
So the text vectorizer `PipeOpTextVectorizer` is the most time-consuming process (~18.6 s). Running the learner via my `mlr3` implementation (`classif.textmodel_nb`) is the second most time-consuming process (~4.6 s). Compare this to the milliseconds it took for the `quanteda.textmodels` implementation to train the learner. The difference in the predict times between the two implementations is also impressive.

So, my conclusion is: if there's demand from `mlr3` users for a better implementation of text classification, you may want to consider creating a `DataBackendDFM` backend at some point.
I don't disagree.
What is the exact difference between a `dfm` and a sparse matrix?

I mean, e.g. `xgboost` works on sparse matrices afaik, so returning a sparse matrix might for example be what is required if you want to do text analysis with it. Thus relying on a more widely used format might be useful.
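For instance (a quick sketch; `xgb.DMatrix` accepts a `dgCMatrix` directly, so no dense detour is needed):

library(Matrix)
library(xgboost)

X = rsparsematrix(100, 1000, density = 0.01)  # sparse feature matrix
y = rbinom(100, 1, 0.5)
dtrain = xgb.DMatrix(data = X, label = y)     # sparse input, no dense conversion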
EDIT

So my proposal was basically to convert to a format that is more widely used, so we can use things for more purposes. For sparse matrices: at the end of the TextVectorizer we convert to a `sparseMatrix`, and before going into text mining we convert back to a `dfm`. If this is also inefficient, we should perhaps really think about a DFM backend.
So a `dfm` is, according to its authors:

"[...] a type of Matrix-class object with additional slots, described below [in `dfm-class {quanteda}`]. quanteda uses two subclasses of the dfm class, depending on whether the object can be represented by a sparse matrix, in which case it is a dfm class object, or if dense, then a dfmDense object."
You are right that it's a `Matrix` at its core, and I did notice that the function `.transform_tfidf` in `PipeOpTextVectorizer.R` returns a sparse matrix. But I'm afraid that I still don't understand how we can have `PipeOpTextVectorizer` return the data in `Matrix` rather than `data.table` format, even if we use `.train_task` instead of `.train_dt`. Doesn't `.train_task` internally convert the data to `data.table` anyway? You mention earlier that "AFAIK `cbind` converts to `DataBackendXXX` automatically depending on what you cbind." So, let's try to cbind a `matrix` to a task (hoping to have the task store the data in `matrix` format):
task_iris <- tsk('iris')
task_iris$cbind(data.matrix(iris))
#Error in assert_backend(backend) :
# Assertion on 'backend' failed: Must inherit from class 'DataBackend', but has class 'matrix'.
It won't even let us do it. Same for
task_iris <- tsk('iris')
task_iris$cbind(Matrix(data.matrix(iris)))
#Error in assert_backend(backend) :
# Assertion on 'backend' failed: Must inherit from class 'DataBackend', but has class 'dgeMatrix'.
Also, re your comment "I mean e.g. `xgboost` works on sparse matrices afaik, so returning a sparse matrix might for example be what is required if you want to do text analysis with it": indeed, that's the format `xgboost` works with. But note that the `mlr3` implementation of `xgboost` takes the data in `data.table` format, then converts it to `matrix` and then to `xgb.DMatrix`:

data = xgboost::xgb.DMatrix(data = data.matrix(data), label = label)

Unless I'm missing something (and I probably am), the backend always seems to be `DataBackendDataTable`, not `DataBackendMatrix`, in all pipeops and learners.

I do see the value of having as few dependencies as possible, and why you'd rather avoid developing backends that are hard or impossible to generalize (e.g. a DFM backend). It's great that `mlr3` internally converts the data to the format required by each algorithm, so we need to think of a clever way to do that for the `quanteda.textmodels` learners with as few data conversions as possible.

I'm sorry that I don't really have a better proposal at this point, but I'm still quite confused (as you may have noticed!).
EDIT

Look how much faster it is to convert a `Matrix` to `dfm` than it is for a `matrix` or `data.frame`:
library(quanteda)
library(microbenchmark)
# Movie corpus data in 'corpus' format
corp_movies <- data_corpus_moviereviews
summary(corp_movies, 5)
# Convert corpus in different data formats
df <- convert(dfm(corp_movies), 'data.frame')
dm <- data.matrix(df)
dM <- Matrix(dm)
# Convert back to dfm to measure time needed
microbenchmark(as.dfm(df), as.dfm(dm), as.dfm(dM), times = 3)
#Unit: milliseconds
# expr min lq mean median uq max neval
# as.dfm(df) 46917.033290 47739.43571 49116.06864 48561.83812 50215.58632 51869.33451 3
# as.dfm(dm) 99.496684 104.69127 119.84104 109.88586 130.01322 150.14058 3
# as.dfm(dM) 9.378019 10.42502 13.41174 11.47202 15.42859 19.38517 3
I think the benchmark is not really what we want to measure. What we are interested in is dfm -> sparse matrix -> dfm vs. dfm -> data.frame/data.table -> dfm.

"Doesn't .train_task internally convert the data to data.table anyway?" No, we can decide what it does. We can basically use any available backend (write a DFM Backend or use a Matrix Backend).
What works is:
library(Matrix)
task_iris$cbind(as_data_backend(Matrix(as.matrix(iris[,2:3]))))
"Unless I'm missing something (and I probably am), the backend always seems to be DataBackendDataTable, not DataBackendMatrix, in all pipeops and learners." Yes, this is the go-to for most cases. But the Task stores the data in whatever format the Backend specifies, and unless we explicitly change the data, this is not changed.
m = Matrix(sample(0:1, size = 150*10^4, replace = TRUE, prob = c(0.99, 0.01)), nrow = 150)
colnames(m) = paste0("x", 1:10^4)
t = tsk("iris")
t$cbind(as_data_backend(m))
`t` now contains a reference to a `data.table` backend (the iris features) and a `Matrix` backend (our matrix). This is only converted to a `data.table` if we call `t$data()`.
I wrote a sparse PCA some time ago; it uses the sparse data from a Task: https://github.com/mlr-org/mlr3pipelines/blob/master/attic/PipeOpSparsePCA.R
EDIT

There currently seems to be a minor problem with `Task` only allowing to return a `data.table` and no sparse format, but this can be circumvented by directly accessing the backend (see the sparse PCA). I would hope we can get this solved as well, though.
Thanks, things are much clearer now.

As it turns out, it's the conversion from `dfm` to `matrix` with `quanteda::convert` that slows things down in `PipeOpTextVectorizer`. Converting the `matrix` to a `data.frame` (`data.table` backend) a few lines later also adds some inefficiency relative to using a `Matrix` backend, but not too much compared to the conversion from `dfm` to `matrix`.

So there are two things to be measured (in terms of processing time) here:

1. Convert `dfm` to `matrix` vs. don't convert `dfm` to `matrix`.
2. Convert `matrix` to `data.frame` vs. convert `matrix` to `Matrix`.

1. Convert `dfm` to `matrix` vs. don't convert `dfm` to `matrix`.

library(mlr3verse)
library(mlr3learners)
library(mlr3pipelines)
library(mlr3misc)
library(quanteda)
library(quanteda.textmodels)
library(dplyr)
library(microbenchmark)
library(R6)
library(Matrix)
library(checkmate)
library(paradox)
# Movie corpus data in 'corpus' format
corp_movies <- data_corpus_moviereviews
summary(corp_movies, 5)
class(corp_movies)
corp_movies_df <- convert(corp_movies, to = 'data.frame') %>%
select(sentiment, text)
task <- TaskClassif$new('movies', corp_movies_df,
target = 'sentiment')
# Grab the functions we need from PipeOpTextVectorizer
transform_tokens = function(text) {
corpus = corpus(text)
# tokenize
tkn = tokens(corpus)
tokens_ngrams(tkn)
}
transform_bow = function(tkn, trim) {
remove = stopwords::stopwords(source = "smart")
# document-feature matrix
tdm = quanteda::dfm(x = tkn, remove = remove)
tdm
}
transform_tfidf = function(tdm, docfreq) {
if (!quanteda::nfeat(tdm)) return(tdm)
# code copied from quanteda:::dfm_tfidf.dfm (adapting here to avoid train/test leakage)
x = quanteda::dfm_weight(x = tdm)
v = docfreq
j = methods::as(x, "dgTMatrix")@j + 1L
x@x = x@x * v[j]
x
}
dt = task$data()[, -1]
to_matrix <- function(column) {
tkn = transform_tokens(column)
tdm = transform_bow(tkn, trim = TRUE) # transform to BOW (bag of words), return term count matrix
state = list(
tdm = quanteda::dfm_subset(tdm, FALSE), # empty tdm so we have vocab of training data
docfreq = quanteda::docfreq(tdm)
)
tdm = quanteda::convert(transform_tfidf(tdm, state$docfreq), "matrix")
tdm
}
to_matrix_not <- function(column) {
tkn = transform_tokens(column)
tdm = transform_bow(tkn, trim = TRUE) # transform to BOW (bag of words), return term count matrix
state = list(
tdm = quanteda::dfm_subset(tdm, FALSE), # empty tdm so we have vocab of training data
docfreq = quanteda::docfreq(tdm)
)
#tdm = quanteda::convert(transform_tfidf(tdm, state$docfreq), "matrix")
tdm
}
microbenchmark(
lapply(dt, to_matrix),
lapply(dt, to_matrix_not),
times = 2
)
#Unit: seconds
# expr min lq mean median uq max neval
# lapply(dt, to_matrix) 17.555592 17.555592 20.054524 20.054524 22.553456 22.553456 2
# lapply(dt, to_matrix_not) 3.881944 3.881944 4.593245 4.593245 5.304546 5.304546 2
The difference in processing time is massive. A `dfm` backend would solve this, but only for `quanteda.textmodels`.

2. Convert `matrix` to `data.frame` vs. convert `matrix` to `Matrix`.

I have modified `PipeOpTextVectorizer` to return a `task` instead of a `data.frame`, with a `Matrix` backend:
#' @title PipeOpTextVectorizer
#'
#' @usage NULL
#' @name mlr_pipeops_textvectorizer
#' @format [`R6Class`] object inheriting from [`PipeOpTaskPreproc`]/[`PipeOp`].
#'
#' @description
#' Computes a bag-of-word representation from a (set of) columns.
#' Columns of type `character` are split up into words.
#' Uses the [`quanteda::dfm()`][quanteda::dfm],
#' [`quanteda::dfm_trim()`][quanteda::dfm_trim] from the 'quanteda' package.
#' TF-IDF computation works similarly to [`quanteda::dfm_tfidf()`][quanteda::dfm_tfidf]
#' but has been adjusted for train/test data split using [`quanteda::docfreq()`][quanteda::docfreq]
#' and [`quanteda::dfm_weight()`][quanteda::dfm_weight]
#'
#' In short:
#' * Per default, produces a bag-of-words representation
#' * If `n` is set to values > 1, ngrams are computed
#' * If `df_trim` parameters are set, the bag-of-words is trimmed.
#' * The `scheme_tf` parameter controls term-frequency (per-document, i.e. per-row) weighting
#' * The `scheme_df` parameter controls the document-frequency (per token, i.e. per-column) weighting.
#'
#' Parameters specify arguments to quanteda's `dfm`, `dfm_trim`, `docfreq` and `dfm_weight`.
#' What belongs to what can be obtained from each params `tags` where `tokenizer` are
#' arguments passed on to [`quanteda::dfm()`][quanteda::dfm].
#' Defaults to a bag-of-words representation with token counts as matrix entries.
#'
#' In order to perform the *default* `dfm_tfidf` weighting, set the `scheme_df` parameter to `"inverse"`.
#' The `scheme_df` parameter is initialized to `"unary"`, which disables document frequency weighting.
#'
#' The pipeop works as follows:
#' 1. Words are tokenized using [`quanteda::tokens`].
#' 2. Ngrams are computed using [`quanteda::tokens_ngrams`]
#' 3. A document-frequency matrix is computed using [`quanteda::dfm`]
#' 4. The document-frequency matrix is trimmed using [`quanteda::dfm_trim`] during train-time.
#' 5. The document-frequency matrix is re-weighted (similar to [`quanteda::dfm_tfidf`]) if `scheme_df` is not set to `"unary"`.
#'
#' @section Construction:
#' ```
#' PipeOpTextVectorizer$new(id = "textvectorizer", param_vals = list())
#' ```
#' * `id` :: `character(1)`\cr
#' Identifier of resulting object, default `"textvectorizer"`.
#' * `param_vals` :: named `list`\cr
#' List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction. Default `list()`.
#'
#' @section Input and Output Channels:
#' Input and output channels are inherited from [`PipeOpTaskPreproc`].
#'
#' The output is the input [`Task`][mlr3::Task] with all affected features converted to a bag-of-words
#' representation.
#'
#' @section State:
#' The `$state` is a list with element 'cols': A vector of extracted columns.
#'
#' @section Parameters:
#' The parameters are the parameters inherited from [`PipeOpTaskPreproc`], as well as:
#'
#' * `return_type` :: `character(1)`\cr
#' Whether to return an integer representation ("integer_sequence") or a bag-of-words ("bow").
#' If set to "integer_sequence", tokens are replaced by an integer and padded/truncated to `sequence_length`.
#' If set to "factor_sequence", tokens are replaced by a factor and padded/truncated to `sequence_length`.
#' If set to 'bow', a possibly weighted bag-of-words matrix is returned.
#' Defaults to `bow`.
#'
#' * `stopwords_language` :: `character(1)`\cr
#' Language to use for stopword filtering. Needs to be either `"none"`, a language identifier listed in
#' `stopwords::stopwords_getlanguages("snowball")` (`"de"`, `"en"`, ...) or `"smart"`.
#' `"none"` disables language-specific stopwords.
#' `"smart"` corresponds to `stopwords::stopwords(source = "smart")`, which
#' contains *English* stopwords and also removes one-character strings. Initialized to `"smart"`.\cr
#' * `extra_stopwords` :: `character`\cr
#' Extra stopwords to remove. Must be a `character` vector containing individual tokens to remove. Initialized to `character(0)`.
#' When `n` is set to values greater than 1, this can also contain stop-ngrams.
#'
#' * `tolower` :: `logical(1)`\cr
#' Convert to lower case? See [`quanteda::dfm`]. Default: `TRUE`.
#' * `stem` :: `logical(1)`\cr
#' Perform stemming? See [`quanteda::dfm`]. Default: `FALSE`.
#'
#' * `what` :: `character(1)`\cr
#' Tokenization splitter. See [`quanteda::tokens`]. Default: `word`.
#' * `remove_punct` :: `logical(1)`\cr
#' See [`quanteda::tokens`]. Default: `FALSE`.
#' * `remove_url` :: `logical(1)`\cr
#' See [`quanteda::tokens`]. Default: `FALSE`.
#' * `remove_symbols` :: `logical(1)`\cr
#' See [`quanteda::tokens`]. Default: `FALSE`.
#' * `remove_numbers` :: `logical(1)`\cr
#' See [`quanteda::tokens`]. Default: `FALSE`.
#' * `remove_separators` :: `logical(1)`\cr
#' See [`quanteda::tokens`]. Default: `TRUE`.
#' * `split_hyphens` :: `logical(1)`\cr
#' See [`quanteda::tokens`]. Default: `FALSE`.
#'
#' * `n` :: `integer`\cr
#' Vector of ngram lengths. See [`quanteda::tokens_ngrams`]. Initialized to 1, deviating from the base function's default.
#' Note that this can be a *vector* of multiple values, to construct ngrams of multiple orders.
#' * `skip` :: `integer`\cr
#' Vector of skips. See [`quanteda::tokens_ngrams`]. Default: 0. Note that this can be a *vector* of multiple values.
#'
#' * `sparsity` :: `numeric(1)`\cr
#' Desired sparsity of the 'tfm' matrix. See [`quanteda::dfm_trim`]. Default: `NULL`.
#' * `max_termfreq` :: `numeric(1)`\cr
#' Maximum term frequency in the 'tfm' matrix. See [`quanteda::dfm_trim`]. Default: `NULL`.
#' * `min_termfreq` :: `numeric(1)`\cr
#' Minimum term frequency in the 'tfm' matrix. See [`quanteda::dfm_trim`]. Default: `NULL`.
#' * `termfreq_type` :: `character(1)`\cr
#' How to assess term frequency. See [`quanteda::dfm_trim`]. Default: `"count"`.
#'
#' * `scheme_df` :: `character(1)` \cr
#' Weighting scheme for document frequency: See [`quanteda::docfreq`]. Initialized to `"unary"` (1 for each document, deviating from base function default).
#' * `smoothing_df` :: `numeric(1)`\cr
#' See [`quanteda::docfreq`]. Default: 0.
#' * `k_df` :: `numeric(1)`\cr
#' `k` parameter given to [`quanteda::docfreq`] (see there).
#' Default is 0.
#' * `threshold_df` :: `numeric(1)`\cr
#' See [`quanteda::docfreq`]. Default: 0. Only considered for `scheme_df` = `"count"`.
#' * `base_df` :: `numeric(1)`\cr
#' The base for logarithms in [`quanteda::docfreq`] (see there). Default: 10.
#'
#' * `scheme_tf` :: `character(1)` \cr
#' Weighting scheme for term frequency: See [`quanteda::dfm_weight`]. Default: `"count"`.
#' * `k_tf` :: `numeric(1)`\cr
#' `k` parameter given to [`quanteda::dfm_weight`] (see there).
#' Default behaviour is 0.5.
#' * `base_tf` :: `numeric(1)`\cr
#' The base for logarithms in [`quanteda::dfm_weight`] (see there). Default: 10.
#'
#' * `sequence_length` :: `integer(1)`\cr
#' The length of the integer sequence. Defaults to `Inf`, i.e. all texts are padded to the length
#' of the longest text. Only relevant for "return_type" : "integer_sequence"
#'
#' @section Internals:
#' See Description. Internally uses the `quanteda` package. Calls [`quanteda::tokens`], [`quanteda::tokens_ngrams`] and [`quanteda::dfm`]. During training,
#' [`quanteda::dfm_trim`] is also called. Tokens not seen during training are dropped during prediction.
#'
#' @section Methods:
#' Only methods inherited from [`PipeOpTaskPreproc`]/[`PipeOp`].
#'
#' @examples
#' library("mlr3")
#' library("data.table")
#' # create some text data
#' dt = data.table(
#' txt = replicate(150, paste0(sample(letters, 3), collapse = " "))
#' )
#' task = tsk("iris")$cbind(dt)
#'
#' pos = po("textvectorizer", param_vals = list(stopwords_language = "en"))
#'
#' pos$train(list(task))[[1]]$data()
#'
#' one_line_of_iris = task$filter(13)
#'
#' one_line_of_iris$data()
#'
#' pos$predict(list(one_line_of_iris))[[1]]$data()
#' @family PipeOps
#' @include PipeOpTaskPreproc.R
#' @export
PipeOpTextVectorizer = R6Class("PipeOpTextVectorizer",
inherit = PipeOpTaskPreproc,
public = list(
initialize = function(id = "textvectorizer", param_vals = list()) {
ps = ParamSet$new(params = list(
ParamFct$new("stopwords_language", tags = c("train", "predict"),
levels = c("da", "de", "en", "es", "fi", "fr", "hu", "ir", "it",
"nl", "no", "pt", "ro", "ru", "sv" , "smart", "none")),
ParamUty$new("extra_stopwords", tags = c("train", "predict"), custom_check = check_character),
ParamLgl$new("tolower", default = TRUE, tags = c("train", "predict", "dfm")),
ParamLgl$new("stem", default = FALSE, tags = c("train", "predict", "dfm")),
ParamFct$new("what", default = "word", tags = c("train", "predict", "tokenizer"),
levels = c("word", "word1", "fasterword", "fastestword", "character", "sentence")),
ParamLgl$new("remove_punct", default = FALSE, tags = c("train", "predict", "tokenizer")),
ParamLgl$new("remove_symbols", default = FALSE, tags = c("train", "predict", "tokenizer")),
ParamLgl$new("remove_numbers", default = FALSE, tags = c("train", "predict", "tokenizer")),
ParamLgl$new("remove_url", default = FALSE, tags = c("train", "predict", "tokenizer")),
ParamLgl$new("remove_separators", default = TRUE, tags = c("train", "predict", "tokenizer")),
ParamLgl$new("split_hyphens", default = FALSE, tags = c("train", "predict", "tokenizer")),
ParamUty$new("n", default = 2, tags = c("train", "predict", "ngrams"), custom_check = curry(check_integerish, min.len = 1, lower = 1, any.missing = FALSE)),
ParamUty$new("skip", default = 0, tags = c("train", "predict", "ngrams"), custom_check = curry(check_integerish, min.len = 1, lower = 0, any.missing = FALSE)),
ParamDbl$new("sparsity", lower = 0, upper = 1, default = NULL,
tags = c("train", "dfm_trim"), special_vals = list(NULL)),
ParamFct$new("termfreq_type", default = "count", tags = c("train", "dfm_trim"),
levels = c("count", "prop", "rank", "quantile")),
ParamDbl$new("min_termfreq", lower = 0, default = NULL,
tags = c("train", "dfm_trim"), special_vals = list(NULL)),
ParamDbl$new("max_termfreq", lower = 0, default = NULL,
tags = c("train", "dfm_trim"), special_vals = list(NULL)),
ParamFct$new("scheme_df", default = "count", tags = c("train", "docfreq"),
levels = c("count", "inverse", "inversemax", "inverseprob", "unary")),
ParamDbl$new("smoothing_df", lower = 0, default = 0, tags = c("train", "docfreq")),
ParamDbl$new("k_df", lower = 0, tags = c("train", "docfreq")),
ParamDbl$new("threshold_df", lower = 0, default = 0, tags = c("train", "docfreq")),
ParamDbl$new("base_df", lower = 0, default = 10, tags = c("train", "docfreq")),
ParamFct$new("scheme_tf", default = "count", tags = c("train", "predict", "dfm_weight"),
levels = c("count", "prop", "propmax", "logcount", "boolean", "augmented", "logave")),
ParamDbl$new("k_tf", lower = 0, upper = 1, tags = c("train", "predict", "dfm_weight")),
ParamDbl$new("base_tf", lower = 0, default = 10, tags = c("train", "predict", "dfm_weight")),
ParamFct$new("return_type", default = "bow", levels = c("bow", "integer_sequence", "factor_sequence"), tags = c("train", "predict")),
ParamInt$new("sequence_length", default = 0, lower = 0, upper = Inf, tags = c("train", "predict", "integer_sequence"))
))$
add_dep("base_df", "scheme_df", CondAnyOf$new(c("inverse", "inversemax", "inverseprob")))$
add_dep("smoothing_df", "scheme_df", CondAnyOf$new(c("inverse", "inversemax", "inverseprob")))$
add_dep("k_df", "scheme_df", CondAnyOf$new(c("inverse", "inversemax", "inverseprob")))$
add_dep("base_df", "scheme_df", CondAnyOf$new(c("inverse", "inversemax", "inverseprob")))$
add_dep("threshold_df", "scheme_df", CondEqual$new("count"))$
add_dep("k_tf", "scheme_tf", CondEqual$new("augmented"))$
add_dep("base_tf", "scheme_tf", CondAnyOf$new(c("logcount", "logave")))$
add_dep("scheme_tf", "return_type", CondEqual$new("bow"))$
add_dep("sparsity", "return_type", CondEqual$new("bow"))$
add_dep("sequence_length", "return_type", CondAnyOf$new(c("integer_sequence", "factor_sequence")))
ps$values = list(stopwords_language = "smart", extra_stopwords = character(0), n = 1, scheme_df = "unary", return_type = "bow")
super$initialize(id = id, param_set = ps, param_vals = param_vals, packages = c("quanteda", "stopwords"), feature_types = "character")
}
),
private = list(
.train_task = function(task) {
fn = task$feature_names
dt = task$data()[, -1]
colwise_results = lapply(dt, function(column) {
tkn = private$.transform_tokens(column)
tdm = private$.transform_bow(tkn, trim = TRUE) # transform to BOW (bag of words), return term count matrix
state = list(
tdm = quanteda::dfm_subset(tdm, FALSE), # empty tdm so we have vocab of training data
docfreq = invoke(quanteda::docfreq, .args = c(list(x = tdm), # column weights
rename_list(self$param_set$get_values(tags = "docfreq"), "_df$", "")))
)
if (self$param_set$values$return_type == "bow") {
matrix_res = quanteda::convert(private$.transform_tfidf(tdm, state$docfreq), "matrix")
} else {
matrix_res = private$.transform_integer_sequence(tkn, tdm, state)
}
#list(state = state, matrix = matrix_res)
matrix_res
})
#self$state = list(colmodels = map(colwise_results, "state"))
colwise_results <- do.call(cbind, colwise_results)
task$select(setdiff(fn, fn))$cbind(as_data_backend(Matrix(colwise_results)))
task
},
.predict_task = function(task) {
fn = task$feature_names
dt = task$data()[, -1]
colwise_results = imap(dt, function(column, colname) {
state = self$state$colmodels[[colname]]
if (nrow(dt)) {
tkn = private$.transform_tokens(column)
tdm = private$.transform_bow(tkn, trim = TRUE)
tdm = rbind(tdm, state$tdm) # make sure all columns occur
tdm = tdm[, colnames(state$tdm)] # Ensure only train-time features are passed on
if (self$param_set$values$return_type == "bow") {
tdm = quanteda::convert(private$.transform_tfidf(tdm, state$docfreq), "matrix")
} else {
tdm = private$.transform_integer_sequence(tkn, tdm, state)
}
} else {
tdm = quanteda::convert(state$tdm, "matrix")
}
tdm
}) %>%
do.call(what = cbind)
task$select(setdiff(fn, fn))$cbind(as_data_backend(Matrix(colwise_results)))
task
},
# text: character vector of feature column
.transform_tokens = function(text) {
corpus = quanteda::corpus(text)
# tokenize
tkn = invoke(quanteda::tokens, .args = c(list(x = corpus), self$param_set$get_values(tags = "tokenizer")))
invoke(quanteda::tokens_ngrams, .args = c(list(x = tkn), self$param_set$get_values(tags = "ngrams")))
},
# tkn: tokenized text, result from `.transform_tokens`
# trim: TRUE during training: trim infrequent features
.transform_bow = function(tkn, trim) {
pv = self$param_set$get_values()
remove = NULL
if (pv$stopwords_language != "none") {
if (pv$stopwords_language == "smart") {
remove = stopwords::stopwords(source = "smart")
} else {
remove = stopwords::stopwords(language = self$param_set$get_values()$stopwords_language)
}
}
remove = c(remove, pv$extra_stopwords)
# document-feature matrix
tdm = invoke(quanteda::dfm, .args = c(list(x = tkn, remove = remove), self$param_set$get_values(tags = "dfm")))
# trim rare tokens
if (trim) {
invoke(quanteda::dfm_trim, .args = c(list(x = tdm), self$param_set$get_values(tags = "dfm_trim")))
} else {
tdm
}
},
.transform_integer_sequence = function(tkn, tdm, state) {
# List of allowed tokens:
pv = insert_named(list(min_termfreq = 0, max_termfreq = Inf), self$param_set$get_values(tags = "dfm_trim"))
dt = data.table(data.table(feature = names(state$docfreq), frequency = state$docfreq))
tokens = unname(unclass(tkn))
dict = attr(tokens, "types")
dict = setkeyv(data.table(k = dict, v = seq_along(dict)), "k")
dict = dict[dt][pv$min_termfreq < get("frequency") & get("frequency") < pv$max_termfreq,]
# pad or cut x to length l
pad0 = function(x, l) {
c(x[seq_len(min(length(x), l))], rep(0, max(0, l - length(x))))
}
il = self$param_set$values$sequence_length
if (is.null(il)) il = max(map_int(tokens, length))
tokens = map(tokens, function(x) {
x = pad0(ifelse(x %in% dict$v, x, 0), il)
data.table(matrix(x, nrow = 1))
})
tokens = rbindlist(tokens)
if (self$param_set$values$return_type == "factor_sequence") {
tokens = map_dtc(tokens, factor, levels = dict$v[!is.na(dict$v)], labels = dict$k[!is.na(dict$v)])
}
tokens
},
.transform_tfidf = function(tdm, docfreq) {
if (!quanteda::nfeat(tdm)) return(tdm)
# code copied from quanteda:::dfm_tfidf.dfm (adapting here to avoid train/test leakage)
x = invoke(quanteda::dfm_weight, .args = c(list(x = tdm),
rename_list(self$param_set$get_values("dfm_weight"), "_tf$", "")))
v = docfreq
j = methods::as(x, "dgTMatrix")@j + 1L
x@x = x@x * v[j]
x
}
)
)
mlr_pipeops$add("textvectorizer", PipeOpTextVectorizer)
The original `PipeOpTextVectorizer`, now renamed `PipeOpTextVectorizerOrig`:
#' @title PipeOpTextVectorizerOrig
#'
#' @usage NULL
#' @name mlr_pipeops_textvectorizerorig
#' @format [`R6Class`] object inheriting from [`PipeOpTaskPreproc`]/[`PipeOp`].
#'
#' @description
#' Computes a bag-of-word representation from a (set of) columns.
#' Columns of type `character` are split up into words.
#' Uses the [`quanteda::dfm()`][quanteda::dfm],
#' [`quanteda::dfm_trim()`][quanteda::dfm_trim] from the 'quanteda' package.
#' TF-IDF computation works similarly to [`quanteda::dfm_tfidf()`][quanteda::dfm_tfidf]
#' but has been adjusted for train/test data split using [`quanteda::docfreq()`][quanteda::docfreq]
#' and [`quanteda::dfm_weight()`][quanteda::dfm_weight]
#'
#' In short:
#' * Per default, produces a bag-of-words representation
#' * If `n` is set to values > 1, ngrams are computed
#' * If `df_trim` parameters are set, the bag-of-words is trimmed.
#' * The `scheme_tf` parameter controls term-frequency (per-document, i.e. per-row) weighting
#' * The `scheme_df` parameter controls the document-frequency (per token, i.e. per-column) weighting.
#'
#' Parameters specify arguments to quanteda's `dfm`, `dfm_trim`, `docfreq` and `dfm_weight`.
#' What belongs to what can be obtained from each params `tags` where `tokenizer` are
#' arguments passed on to [`quanteda::dfm()`][quanteda::dfm].
#' Defaults to a bag-of-words representation with token counts as matrix entries.
#'
#' In order to perform the *default* `dfm_tfidf` weighting, set the `scheme_df` parameter to `"inverse"`.
#' The `scheme_df` parameter is initialized to `"unary"`, which disables document frequency weighting.
#'
#' The pipeop works as follows:
#' 1. Words are tokenized using [`quanteda::tokens`].
#' 2. Ngrams are computed using [`quanteda::tokens_ngrams`]
#' 3. A document-frequency matrix is computed using [`quanteda::dfm`]
#' 4. The document-frequency matrix is trimmed using [`quanteda::dfm_trim`] during train-time.
#' 5. The document-frequency matrix is re-weighted (similar to [`quanteda::dfm_tfidf`]) if `scheme_df` is not set to `"unary"`.
#'
#' @section Construction:
#' ```
#' PipeOpTextVectorizerOrig$new(id = "textvectorizerorig", param_vals = list())
#' ```
#' * `id` :: `character(1)`\cr
#' Identifier of resulting object, default `"textvectorizerorig"`.
#' * `param_vals` :: named `list`\cr
#' List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction. Default `list()`.
#'
#' @section Input and Output Channels:
#' Input and output channels are inherited from [`PipeOpTaskPreproc`].
#'
#' The output is the input [`Task`][mlr3::Task] with all affected features converted to a bag-of-words
#' representation.
#'
#' @section State:
#' The `$state` is a list with element 'cols': A vector of extracted columns.
#'
#' @section Parameters:
#' The parameters are the parameters inherited from [`PipeOpTaskPreproc`], as well as:
#'
#' * `return_type` :: `character(1)`\cr
#' Whether to return an integer representation ("integer_sequence") or a bag-of-words ("bow").
#' If set to "integer_sequence", tokens are replaced by an integer and padded/truncated to `sequence_length`.
#' If set to "factor_sequence", tokens are replaced by a factor and padded/truncated to `sequence_length`.
#' If set to 'bow', a possibly weighted bag-of-words matrix is returned.
#' Defaults to `bow`.
#'
#' * `stopwords_language` :: `character(1)`\cr
#' Language to use for stopword filtering. Needs to be either `"none"`, a language identifier listed in
#' `stopwords::stopwords_getlanguages("snowball")` (`"de"`, `"en"`, ...) or `"smart"`.
#' `"none"` disables language-specific stopwords.
#' `"smart"` corresponds to `stopwords::stopwords(source = "smart")`, which
#' contains *English* stopwords and also removes one-character strings. Initialized to `"smart"`.\cr
#' * `extra_stopwords` :: `character`\cr
#' Extra stopwords to remove. Must be a `character` vector containing individual tokens to remove. Initialized to `character(0)`.
#' When `n` is set to values greater than 1, this can also contain stop-ngrams.
#'
#' * `tolower` :: `logical(1)`\cr
#' Convert to lower case? See [`quanteda::dfm`]. Default: `TRUE`.
#' * `stem` :: `logical(1)`\cr
#' Perform stemming? See [`quanteda::dfm`]. Default: `FALSE`.
#'
#' * `what` :: `character(1)`\cr
#' Tokenization splitter. See [`quanteda::tokens`]. Default: `word`.
#' * `remove_punct` :: `logical(1)`\cr
#' See [`quanteda::tokens`]. Default: `FALSE`.
#' * `remove_url` :: `logical(1)`\cr
#' See [`quanteda::tokens`]. Default: `FALSE`.
#' * `remove_symbols` :: `logical(1)`\cr
#' See [`quanteda::tokens`]. Default: `FALSE`.
#' * `remove_numbers` :: `logical(1)`\cr
#' See [`quanteda::tokens`]. Default: `FALSE`.
#' * `remove_separators` :: `logical(1)`\cr
#' See [`quanteda::tokens`]. Default: `TRUE`.
#' * `split_hyphens` :: `logical(1)`\cr
#' See [`quanteda::tokens`]. Default: `FALSE`.
#'
#' * `n` :: `integer`\cr
#' Vector of ngram lengths. See [`quanteda::tokens_ngrams`]. Initialized to 1, deviating from the base function's default.
#' Note that this can be a *vector* of multiple values, to construct ngrams of multiple orders.
#' * `skip` :: `integer`\cr
#' Vector of skips. See [`quanteda::tokens_ngrams`]. Default: 0. Note that this can be a *vector* of multiple values.
#'
#' * `sparsity` :: `numeric(1)`\cr
#' Desired sparsity of the 'tfm' matrix. See [`quanteda::dfm_trim`]. Default: `NULL`.
#' * `max_termfreq` :: `numeric(1)`\cr
#' Maximum term frequency in the 'tfm' matrix. See [`quanteda::dfm_trim`]. Default: `NULL`.
#' * `min_termfreq` :: `numeric(1)`\cr
#' Minimum term frequency in the 'tfm' matrix. See [`quanteda::dfm_trim`]. Default: `NULL`.
#' * `termfreq_type` :: `character(1)`\cr
#' How to assess term frequency. See [`quanteda::dfm_trim`]. Default: `"count"`.
#'
#' * `scheme_df` :: `character(1)` \cr
#' Weighting scheme for document frequency: See [`quanteda::docfreq`]. Initialized to `"unary"` (1 for each document, deviating from base function default).
#' * `smoothing_df` :: `numeric(1)`\cr
#' See [`quanteda::docfreq`]. Default: 0.
#' * `k_df` :: `numeric(1)`\cr
#' `k` parameter given to [`quanteda::docfreq`] (see there).
#' Default is 0.
#' * `threshold_df` :: `numeric(1)`\cr
#' See [`quanteda::docfreq`]. Default: 0. Only considered for `scheme_df` = `"count"`.
#' * `base_df` :: `numeric(1)`\cr
#' The base for logarithms in [`quanteda::docfreq`] (see there). Default: 10.
#'
#' * `scheme_tf` :: `character(1)` \cr
#' Weighting scheme for term frequency: See [`quanteda::dfm_weight`]. Default: `"count"`.
#' * `k_tf` :: `numeric(1)`\cr
#' `k` parameter given to [`quanteda::dfm_weight`] (see there).
#' Default behaviour is 0.5.
#' * `base_tf` :: `numeric(1)`\cr
#' The base for logarithms in [`quanteda::dfm_weight`] (see there). Default: 10.
#'
#' * `sequence_length` :: `integer(1)`\cr
#' The length of the integer sequence. Defaults to `Inf`, i.e. all texts are padded to the length
#' of the longest text. Only relevant for "return_type" : "integer_sequence"
#'
#' @section Internals:
#' See Description. Internally uses the `quanteda` package. Calls [`quanteda::tokens`], [`quanteda::tokens_ngrams`] and [`quanteda::dfm`]. During training,
#' [`quanteda::dfm_trim`] is also called. Tokens not seen during training are dropped during prediction.
#'
#' @section Methods:
#' Only methods inherited from [`PipeOpTaskPreproc`]/[`PipeOp`].
#'
#' @examples
#' library("mlr3")
#' library("data.table")
#' # create some text data
#' dt = data.table(
#' txt = replicate(150, paste0(sample(letters, 3), collapse = " "))
#' )
#' task = tsk("iris")$cbind(dt)
#'
#' pos = po("textvectorizerorig", param_vals = list(stopwords_language = "en"))
#'
#' pos$train(list(task))[[1]]$data()
#'
#' one_line_of_iris = task$filter(13)
#'
#' one_line_of_iris$data()
#'
#' pos$predict(list(one_line_of_iris))[[1]]$data()
#' @family PipeOps
#' @include PipeOpTaskPreproc.R
#' @export
PipeOpTextVectorizerOrig = R6Class("PipeOpTextVectorizerOrig",
inherit = PipeOpTaskPreproc,
public = list(
initialize = function(id = "textvectorizerorig", param_vals = list()) {
ps = ParamSet$new(params = list(
ParamFct$new("stopwords_language", tags = c("train", "predict"),
levels = c("da", "de", "en", "es", "fi", "fr", "hu", "ir", "it",
"nl", "no", "pt", "ro", "ru", "sv" , "smart", "none")),
ParamUty$new("extra_stopwords", tags = c("train", "predict"), custom_check = check_character),
ParamLgl$new("tolower", default = TRUE, tags = c("train", "predict", "dfm")),
ParamLgl$new("stem", default = FALSE, tags = c("train", "predict", "dfm")),
ParamFct$new("what", default = "word", tags = c("train", "predict", "tokenizer"),
levels = c("word", "word1", "fasterword", "fastestword", "character", "sentence")),
ParamLgl$new("remove_punct", default = FALSE, tags = c("train", "predict", "tokenizer")),
ParamLgl$new("remove_symbols", default = FALSE, tags = c("train", "predict", "tokenizer")),
ParamLgl$new("remove_numbers", default = FALSE, tags = c("train", "predict", "tokenizer")),
ParamLgl$new("remove_url", default = FALSE, tags = c("train", "predict", "tokenizer")),
ParamLgl$new("remove_separators", default = TRUE, tags = c("train", "predict", "tokenizer")),
ParamLgl$new("split_hyphens", default = FALSE, tags = c("train", "predict", "tokenizer")),
ParamUty$new("n", default = 2, tags = c("train", "predict", "ngrams"), custom_check = curry(check_integerish, min.len = 1, lower = 1, any.missing = FALSE)),
ParamUty$new("skip", default = 0, tags = c("train", "predict", "ngrams"), custom_check = curry(check_integerish, min.len = 1, lower = 0, any.missing = FALSE)),
ParamDbl$new("sparsity", lower = 0, upper = 1, default = NULL,
tags = c("train", "dfm_trim"), special_vals = list(NULL)),
ParamFct$new("termfreq_type", default = "count", tags = c("train", "dfm_trim"),
levels = c("count", "prop", "rank", "quantile")),
ParamDbl$new("min_termfreq", lower = 0, default = NULL,
tags = c("train", "dfm_trim"), special_vals = list(NULL)),
ParamDbl$new("max_termfreq", lower = 0, default = NULL,
tags = c("train", "dfm_trim"), special_vals = list(NULL)),
ParamFct$new("scheme_df", default = "count", tags = c("train", "docfreq"),
levels = c("count", "inverse", "inversemax", "inverseprob", "unary")),
ParamDbl$new("smoothing_df", lower = 0, default = 0, tags = c("train", "docfreq")),
ParamDbl$new("k_df", lower = 0, tags = c("train", "docfreq")),
ParamDbl$new("threshold_df", lower = 0, default = 0, tags = c("train", "docfreq")),
ParamDbl$new("base_df", lower = 0, default = 10, tags = c("train", "docfreq")),
ParamFct$new("scheme_tf", default = "count", tags = c("train", "predict", "dfm_weight"),
levels = c("count", "prop", "propmax", "logcount", "boolean", "augmented", "logave")),
ParamDbl$new("k_tf", lower = 0, upper = 1, tags = c("train", "predict", "dfm_weight")),
ParamDbl$new("base_tf", lower = 0, default = 10, tags = c("train", "predict", "dfm_weight")),
ParamFct$new("return_type", default = "bow", levels = c("bow", "integer_sequence", "factor_sequence"), tags = c("train", "predict")),
ParamInt$new("sequence_length", default = 0, lower = 0, upper = Inf, tags = c("train", "predict", "integer_sequence"))
))$
add_dep("base_df", "scheme_df", CondAnyOf$new(c("inverse", "inversemax", "inverseprob")))$
add_dep("smoothing_df", "scheme_df", CondAnyOf$new(c("inverse", "inversemax", "inverseprob")))$
add_dep("k_df", "scheme_df", CondAnyOf$new(c("inverse", "inversemax", "inverseprob")))$
add_dep("base_df", "scheme_df", CondAnyOf$new(c("inverse", "inversemax", "inverseprob")))$
add_dep("threshold_df", "scheme_df", CondEqual$new("count"))$
add_dep("k_tf", "scheme_tf", CondEqual$new("augmented"))$
add_dep("base_tf", "scheme_tf", CondAnyOf$new(c("logcount", "logave")))$
add_dep("scheme_tf", "return_type", CondEqual$new("bow"))$
add_dep("sparsity", "return_type", CondEqual$new("bow"))$
add_dep("sequence_length", "return_type", CondAnyOf$new(c("integer_sequence", "factor_sequence")))
ps$values = list(stopwords_language = "smart", extra_stopwords = character(0), n = 1, scheme_df = "unary", return_type = "bow")
super$initialize(id = id, param_set = ps, param_vals = param_vals, packages = c("quanteda", "stopwords"), feature_types = "character")
}
),
private = list(
.train_dt = function(dt, levels, target) {
colwise_results = sapply(dt, function(column) {
tkn = private$.transform_tokens(column)
tdm = private$.transform_bow(tkn, trim = TRUE) # transform to BOW (bag of words), return term count matrix
state = list(
tdm = quanteda::dfm_subset(tdm, FALSE), # empty tdm so we have vocab of training data
docfreq = invoke(quanteda::docfreq, .args = c(list(x = tdm), # column weights
rename_list(self$param_set$get_values(tags = "docfreq"), "_df$", "")))
)
if (self$param_set$values$return_type == "bow") {
matrix = quanteda::convert(private$.transform_tfidf(tdm, state$docfreq), "matrix")
} else {
matrix = private$.transform_integer_sequence(tkn, tdm, state)
}
list(state = state, matrix = matrix)
}, simplify = FALSE)
self$state = list(colmodels = map(colwise_results, "state"))
as.data.frame(map(colwise_results, "matrix"))
},
.predict_dt = function(dt, levels, target) {
colwise_results = imap(dt, function(column, colname) {
state = self$state$colmodels[[colname]]
if (nrow(dt)) {
tkn = private$.transform_tokens(column)
tdm = private$.transform_bow(tkn, trim = TRUE)
tdm = rbind(tdm, state$tdm) # make sure all columns occur
tdm = tdm[, colnames(state$tdm)] # Ensure only train-time features are passed on
if (self$param_set$values$return_type == "bow") {
tdm = quanteda::convert(private$.transform_tfidf(tdm, state$docfreq), "matrix")
} else {
tdm = private$.transform_integer_sequence(tkn, tdm, state)
}
} else {
tdm = quanteda::convert(state$tdm, "matrix")
}
tdm
})
as.data.frame(colwise_results)
},
# text: character vector of feature column
.transform_tokens = function(text) {
corpus = quanteda::corpus(text)
# tokenize
tkn = invoke(quanteda::tokens, .args = c(list(x = corpus), self$param_set$get_values(tags = "tokenizer")))
invoke(quanteda::tokens_ngrams, .args = c(list(x = tkn), self$param_set$get_values(tags = "ngrams")))
},
# tkn: tokenized text, result from `.transform_tokens`
# trim: TRUE during training: trim infrequent features
.transform_bow = function(tkn, trim) {
pv = self$param_set$get_values()
remove = NULL
if (pv$stopwords_language != "none") {
if (pv$stopwords_language == "smart") {
remove = stopwords::stopwords(source = "smart")
} else {
remove = stopwords::stopwords(language = self$param_set$get_values()$stopwords_language)
}
}
remove = c(remove, pv$extra_stopwords)
# document-feature matrix
tdm = invoke(quanteda::dfm, .args = c(list(x = tkn, remove = remove), self$param_set$get_values(tags = "dfm")))
# trim rare tokens
if (trim) {
invoke(quanteda::dfm_trim, .args = c(list(x = tdm), self$param_set$get_values(tags = "dfm_trim")))
} else {
tdm
}
},
.transform_integer_sequence = function(tkn, tdm, state) {
# List of allowed tokens:
pv = insert_named(list(min_termfreq = 0, max_termfreq = Inf), self$param_set$get_values(tags = "dfm_trim"))
dt = data.table(data.table(feature = names(state$docfreq), frequency = state$docfreq))
tokens = unname(unclass(tkn))
dict = attr(tokens, "types")
dict = setkeyv(data.table(k = dict, v = seq_along(dict)), "k")
dict = dict[dt][pv$min_termfreq < get("frequency") & get("frequency") < pv$max_termfreq,]
# pad or cut x to length l
pad0 = function(x, l) {
c(x[seq_len(min(length(x), l))], rep(0, max(0, l - length(x))))
}
il = self$param_set$values$sequence_length
if (is.null(il)) il = max(map_int(tokens, length))
tokens = map(tokens, function(x) {
x = pad0(ifelse(x %in% dict$v, x, 0), il)
data.table(matrix(x, nrow = 1))
})
tokens = rbindlist(tokens)
if (self$param_set$values$return_type == "factor_sequence") {
tokens = map_dtc(tokens, factor, levels = dict$v[!is.na(dict$v)], labels = dict$k[!is.na(dict$v)])
}
tokens
},
.transform_tfidf = function(tdm, docfreq) {
if (!quanteda::nfeat(tdm)) return(tdm)
# code copied from quanteda:::dfm_tfidf.dfm (adapting here to avoid train/test leakage)
x = invoke(quanteda::dfm_weight, .args = c(list(x = tdm),
rename_list(self$param_set$get_values("dfm_weight"), "_tf$", "")))
v = docfreq
j = methods::as(x, "dgTMatrix")@j + 1L
x@x = x@x * v[j]
x
}
)
)
mlr_pipeops$add("textvectorizerorig", PipeOpTextVectorizerOrig)
Benchmark the two pipe operators:
# Internal functions the pipeops need to run
# See https://rdrr.io/cran/mlr3pipelines/src/R/utils.R#sym-rename_list
curry = function(fn, ..., varname = "x") {
arguments = list(...)
function(x) {
arguments[[varname]] = x
do.call(fn, arguments)
}
}
rename_list = function(x, ...) {
names(x) = gsub(x = names(x), ...)
x
}
po_text <- po('textvectorizer')
po_text_orig <- po('textvectorizerorig')
microbenchmark(
po_text$train(list(task)),
po_text_orig$train(list(task)),
times = 2
)
#Unit: seconds
# expr min lq mean median uq max neval
# po_text$train(list(task)) 6.609644 6.609644 7.484634 7.484634 8.359624 8.359624 2
# po_text_orig$train(list(task)) 8.586773 8.586773 9.383135 9.383135 10.179497 10.179497 2
Hey, sorry for the silence. Feel free to ping me again if I forget to respond on time.
So I guess the only sensible pipeops solution (1) seems to be to create a `dfm` backend. What we do not know is whether a `Matrix` backend would solve the memory bottleneck.

An alternative (2) would be to just fuse the learner and pipeop into a single large learner, which you also proposed already, I think.

As `dfm` is only relevant for `quanteda` models, I think the better option is to implement (2) in a sensible manner, e.g. by copying over large parts of the `PipeOp`'s code as a function and then adding the different learners on top. I guess those learners would then go into `mlr3extralearners`.
Apologies for the long silence.

I'm working on solution (2), i.e. building a `mlr3extralearners` version of `quanteda`'s multinomial NB model that directly incorporates `mlr3pipelines::PipeOpTextVectorizer`, to avoid the unnecessary data conversions discussed above. As expected, the gains in computation time relative to `mlr3pipelines::PipeOpTextVectorizer` are immense.

I haven't yet finalized the model though. There are a couple of issues that I need to fix:

1. `.transform_integer_sequence` is still slow because it's based on a `data.table` format, which I am planning to change. This is not such an urgent issue right now, because the use of this function is optional rather than the default.
2. The model returns `NA` probabilities, and I'm yet to discover why:

library(mlr3)
library(mlr3learners)
library(mlr3pipelines)
library(quanteda)
library(quanteda.textmodels)
library(dplyr)
library(microbenchmark)
library(checkmate)
# Read util functions from mlr3pipelines or lrn("classif.textmodel_nb") below will complain it can't find them
devtools::source_url("https://raw.githubusercontent.com/mlr-org/mlr3pipelines/master/R/utils.R")
# Movie corpus data in 'corpus' format
corp_movies <- data_corpus_moviereviews
summary(corp_movies, 5)
class(corp_movies)
# Movie corpus data in 'data frame' format. Will be passed to mlr3's task function
corp_movies_df <- convert(corp_movies, to = 'data.frame') %>%
select(sentiment, text)
# Convert movie corpus data frame to task
corp_movies_task <- TaskClassif$new('movies', corp_movies_df,
target = 'sentiment')
nb <- lrn("classif.textmodel_nb")
nb$train(corp_movies_task)
preds <- nb$predict(corp_movies_task)
preds
sum(is.na(preds$data$response))
I'll see what I can do when I have the time!
`dfm` does inherit from `Matrix`, so using the `DataBackendMatrix` works!
library("mlr3")
library("quanteda")
library("quanteda.textmodels")
library("data.table")
dfm_corpus <- dfm(data_corpus_moviereviews)
colnames(dfm_corpus) = gsub("%", "[perc]", colnames(dfm_corpus), fixed = TRUE)
b = as_data_backend(dfm_corpus, dense = data.table(TARGET = rep(1, nrow(dfm_corpus))))
options(mlr3.allow_utf8_names = TRUE)
tr = TaskRegr$new("test", b, target = "TARGET")
tr$data(rows = 1:3, cols = tr$feature_names[1:3])
In principle it should be possible (and maybe even relatively lightweight) to create a `DataBackendMatrix` from the `dfm` object in `PipeOpTextVectorizer`, and then to pry out this object from the `Task`'s backend in the relevant `Learner`.
Mixing backends with different types is getting too complicated, so creating a `dfm` column type is back on the menu.
Hi,

Pipe operator `mlr3pipelines::PipeOpTextVectorizer` is painfully slow in comparison with `quanteda::dfm()`.

I'm not sure why, but my speculation is that it has something to do with the fact that `mlr3pipelines::PipeOpTextVectorizer` creates the document-feature matrix in a `data.frame` (more precisely, a `data.table`) format. This results in a massive table that causes memory issues. By contrast, the `dfm` format of `quanteda::dfm()` is pretty lightweight.

As it currently stands, `mlr3pipelines::PipeOpTextVectorizer` limits opportunities for text analysis on a standard laptop (mine has 8GB of RAM), because it simply sucks up all the RAM even before passing the data to one or more learners.

It would be great if the pipe operator could be modified to allow the user to choose between `data.table` and `dfm` outputs. The latter would then be passed to `quanteda.textmodels::textmodel_nb()`, which is a freakishly fast version of multinomial Naive Bayes. I am currently working on adding `quanteda.textmodels::textmodel_nb()` to `mlr3extralearners`, but the data format requirement is a major blocker. My idea was to use `mlr_learners_classif.textmodel_nb` in conjunction with `mlr3pipelines::PipeOpTextVectorizer` in order to pass the `data.table` data from the latter to the former, and then convert the data from `data.table` to `dfm` inside `mlr_learners_classif.textmodel_nb` using `quanteda::as.dfm`. However, as it turns out, `quanteda::as.dfm` is disappointingly slow. So, it looks like the option of having `mlr3pipelines::PipeOpTextVectorizer` output the data in `dfm` format is a reasonable one.

Let me know what you think!
Thanks