ContextLab / hypertools

A Python toolbox for gaining geometric insights into high-dimensional data
http://hypertools.readthedocs.io/en/latest/
MIT License

Plotting text #179

Closed andrewheusser closed 6 years ago

andrewheusser commented 6 years ago

This PR adds the ability to plot text data. For example:

import hypertools as hyp

data = [['i like cats alot', 'cats r pretty cool', 'cats are better than dogs'],
        ['dogs rule the haus', 'dogs are my jam', 'dogs are a mans best friend']]
hyp.plot(data, 'o')

yields a plot where each dot represents a sentence that was vectorized using sklearn's CountVectorizer and then modeled using LatentDirichletAllocation.
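A rough sketch of that vectorize-then-model pipeline using scikit-learn directly. It assumes all sentences are fit together and the resulting rows are then split back into one matrix per inner list, so each list becomes its own group. The `n_components=2` value and the splitting step are illustrative choices, not necessarily what the PR does internally:

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation

data = [['i like cats alot', 'cats r pretty cool', 'cats are better than dogs'],
        ['dogs rule the haus', 'dogs are my jam', 'dogs are a mans best friend']]

# Flatten the list of lists so the vectorizer and topic model see every sentence.
docs = [doc for group in data for doc in group]

# Sentences -> word-count matrix (n_sentences x vocabulary size).
counts = CountVectorizer().fit_transform(docs)

# Counts -> topic weights (n_sentences x n_components).
lda = LatentDirichletAllocation(n_components=2, random_state=0)
topics = lda.fit_transform(counts)

# Split the rows back into one matrix per original list (one plotted group per list).
splits = np.cumsum([len(group) for group in data])[:-1]
matrices = np.split(topics, splits)
```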

To plot just the vectorized text, call hyp.plot(data, text_model=None).

I exposed the hyp.tools.text2mat function to the user, and that's what does the heavy lifting. It can vectorize the data using CountVectorizer or TfidfVectorizer and model the data using LDA or NMF.
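The vectorizer/model choice described above can be sketched as a small dispatch function. This is a hypothetical `text_to_matrix` helper written for illustration (it is not the hypertools `text2mat` implementation); only the scikit-learn classes are real:

```python
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.decomposition import LatentDirichletAllocation, NMF

# Hypothetical lookup tables mirroring the options named in the PR description.
VECTORIZERS = {'CountVectorizer': CountVectorizer, 'TfidfVectorizer': TfidfVectorizer}
TEXT_MODELS = {'LatentDirichletAllocation': LatentDirichletAllocation, 'NMF': NMF}


def text_to_matrix(docs, vectorizer='CountVectorizer',
                   text_model='LatentDirichletAllocation', n_components=10):
    """Vectorize a list of strings, then optionally fit a topic model to the vectors."""
    vectors = VECTORIZERS[vectorizer]().fit_transform(docs)
    if text_model is None:
        # Skip the modeling step and return the (dense) vectorized text.
        return vectors.toarray()
    model = TEXT_MODELS[text_model](n_components=n_components)
    return model.fit_transform(vectors)
```

For example, `text_to_matrix(docs, vectorizer='TfidfVectorizer', text_model='NMF', n_components=2)` returns one row of topic weights per document.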

jeremymanning commented 6 years ago

This looks awesome!

Is there (or should there be) a default text model (e.g. the Wikipedia model we've been kicking around)? Perhaps we could do something like:

jeremymanning commented 6 years ago

Another thought: should we support word2vec?

andrewheusser commented 6 years ago

Supporting a wiki model would be a great idea, especially for short texts like tweets. Currently, text_model refers to the type of model used for the text data (LDA or NMF), and in both cases the model is derived from the input data. text_model=None simply skips the modeling step and returns the vectorized text, which may be desired in some situations. So maybe just adding a text_model='wiki' option and a text_model=<model object> option would be sufficient. I also think supporting word2vec would be awesome, but I need to do a little research to see what's available.

jeremymanning commented 6 years ago

For word2vec, did you see this library I linked to? https://github.com/danielfrg/word2vec

For text_model, it seems worth defining those arguments similarly to how reduce_model, etc. are defined-- e.g.:

I also haven't fully thought through what happens if some of the data gets specified as matrices, and other data gets specified as text. E.g.:

andrewheusser commented 6 years ago

Ah, I didn't see your word2vec link, but I have used that particular library before. In the same way that we will have a predefined topic model fit to wiki data, we could have a predefined word2vec model fit to the same wiki dataset. If a sentence is passed as the input, it's clear to me what the output of a topic model would be. However, word2vec outputs a vector for each word. Would we (a) average the vectors together, (b) plot a separate point for each word, or (c) use some other behavior?
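Option (a) is simple to sketch. Below, a plain dict of toy 2-D vectors stands in for a trained word2vec model (this does not use the danielfrg/word2vec library's API); a sentence vector is the mean of the vectors for the words the model knows, with a zero vector as an assumed fallback when no word is known:

```python
import numpy as np


def sentence_vector(sentence, embeddings):
    """Option (a): average the vectors of the words that have embeddings."""
    vectors = [embeddings[w] for w in sentence.lower().split() if w in embeddings]
    if not vectors:
        # No known words: fall back to a zero vector of the embedding size.
        dim = len(next(iter(embeddings.values())))
        return np.zeros(dim)
    return np.mean(vectors, axis=0)


# Toy "embeddings" standing in for a trained word2vec model.
embeddings = {'cats': np.array([1.0, 0.0]),
              'dogs': np.array([0.0, 1.0]),
              'rule': np.array([0.0, 3.0])}
```

Here `sentence_vector('dogs rule the haus', embeddings)` averages only the vectors for `dogs` and `rule`, giving `[0.0, 2.0]`. Averaging discards word order, which is the main trade-off versus option (b).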

if the user doesn't specify anything (or sets it to None), we default to a pre-selected and pre-trained model (e.g. LDA with Wikipedia-derived topics). By comparison, if the user sets reduce=None, the analyze function simply returns the data with no dimensionality reduction, whereas the default behavior is to reduce with IncrementalPCA. We could default either to a pretrained model or to deriving the model from the input data. I'm happy with either option, but I'd lean toward a predefined model because it will typically be more stable unless the input data are large.

if the user specifies a string (for a supported model), we use that model with pre-selected parameters (e.g. number of features, etc.) ✔️

if the user specifies a dictionary (whose keys are arguments), we use that to fill in any defined features, reverting to defaults for whatever the user doesn't specify ✔️
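The None / string / dict convention above can be sketched as a small resolver. Everything here is hypothetical: the preset names and `n_components` values are placeholders, and the dict layout (`{'model': ..., 'params': ...}`) is an assumption modeled on how `reduce` arguments are commonly specified. Note that this follows the proposal in the list (None means "use the default model"), not the earlier behavior where `text_model=None` skips modeling:

```python
import copy

# Hypothetical presets; names and parameter values are illustrative only.
DEFAULT_MODEL = 'LatentDirichletAllocation'
PRESETS = {
    'LatentDirichletAllocation': {'n_components': 100},
    'NMF': {'n_components': 100},
}


def resolve_text_model(text_model=None):
    """Return (model_name, params) following the None / string / dict convention."""
    if text_model is None:
        # Nothing specified: fall back to the pre-selected default model.
        return DEFAULT_MODEL, copy.deepcopy(PRESETS[DEFAULT_MODEL])
    if isinstance(text_model, str):
        # A supported model name: use it with its preset parameters.
        return text_model, copy.deepcopy(PRESETS[text_model])
    if isinstance(text_model, dict):
        # User-supplied params override the defaults; unspecified keys keep theirs.
        name = text_model.get('model', DEFAULT_MODEL)
        params = copy.deepcopy(PRESETS[name])
        params.update(text_model.get('params', {}))
        return name, params
    raise TypeError('text_model must be None, a string, or a dict')
```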

I also haven't fully thought through what happens if some of the data gets specified as matrices, and other data gets specified as text. In my current implementation, this behavior is not supported, but if we can make it work, it may be useful.

after fitting the text model, this means that different datasets might have different numbers of dimensions...so how do we deal with that? One idea is to force the text model to have the same number of dims as the other numerical matrices. However, this wouldn't work if we use a predefined model. 🤔

do we support count matrices as an alternative way of inputting "text" data? I think so? if you input a samples by words count matrix to hyp.plot, it will create a plot.

how do we display the results to the user? e.g. can they view the vocabulary and/or the topics? perhaps these (via a model object) should go in the DataGeometry object somehow? and then when new text data are passed to geo.plot, we could use the fitted model to compute topics for the new text data. In the current implementation, the user would have to pass the vocab/text samples as labels (it's not done automatically), and they don't have access to the vocab/topics. I don't see an intuitive place to store this info in the DataGeometry objects, so we may have to think about adding a field if we want to support this behavior. One thing to clarify is whether we want to treat the text_model as the reduce model, or keep them separate. A reason to keep them separate is that you might want to fit the text data to a topic model with, say, 20 topics, but then reduce the data down to 3 dims with PCA or another reduction algorithm.
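Keeping the text model and the reduction model as separate steps can be sketched with a scikit-learn `Pipeline` (an illustration of the design question, not how hypertools stores models). The toy documents are made up. Because the fitted vectorizer and topic model remain accessible as named steps, their vocabulary and topics could be inspected, and new text can be pushed through the same fitted steps:

```python
from sklearn.pipeline import make_pipeline
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation, PCA

# Text model (20 topics) and reduction model (3 dims) kept as separate steps.
pipeline = make_pipeline(
    CountVectorizer(),
    LatentDirichletAllocation(n_components=20, random_state=0),
    PCA(n_components=3),
)

# Made-up documents, just to exercise the pipeline.
docs = [f'cats dogs topic{i % 5} word{i % 7}' for i in range(10)]
coords = pipeline.fit_transform(docs)  # 10 x 3 coordinates for plotting

# The fitted pieces stay available for inspection or for transforming new text.
lda = pipeline.named_steps['latentdirichletallocation']
vocab = sorted(pipeline.named_steps['countvectorizer'].vocabulary_)
new_coords = pipeline.transform(['cats topic1'])
```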

andrewheusser commented 6 years ago

I'm working on the Wikipedia-derived topic model. Right now, I am using the wikipedia Python package to retrieve the text of pages that were specified in the Matlab wiki model. Then, I am planning to use the sklearn CountVectorizer and LatentDirichletAllocation to fit a topic model where n_topics=100. @jeremymanning what should I use for the alpha parameter? It defaults to 1/n_topics, but the Matlab version seems to have used 25/n_topics:

with K being the number of topics and a tag indicating how the alpha parameter was set (I included 'est', where it is estimated from data, and 'gs25', where it is set to 25/K). The latter approach seems to lead to more interpretable models, from a very subjective perspective, and is close to what Griffiths 2007 suggests.
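In scikit-learn, the alpha parameter is exposed as `doc_topic_prior` on `LatentDirichletAllocation` (it defaults to `1 / n_components`). A minimal sketch of the 'gs25' setting, assuming K = 100 topics; `learning_method` is set explicitly so the result doesn't depend on the library version's default:

```python
from sklearn.decomposition import LatentDirichletAllocation

n_topics = 100  # K

# alpha = 25 / K, matching the 'gs25' setting rather than sklearn's 1 / K default.
lda = LatentDirichletAllocation(
    n_components=n_topics,
    doc_topic_prior=25 / n_topics,
    learning_method='batch',
    random_state=0,
)
```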

Any other parameters I should change from default?

andrewheusser commented 6 years ago

I created a model with 100 topics and otherwise default parameters, except for learning_method, which I changed from 'online' to 'batch' because it threw a warning that the default will change in the next release. It's pretty large (500 MB). How do we want to handle the data? There are a few options I can think of, with varying times to implement:

1) Create a local folder for data, and if the model isn't there, download it
2) Load it on the fly from Google Drive
3) Manually load it (e.g. wiki = hyp.load('wiki'); hyp.plot(text, text_model=wiki))

Option 1 seems like the most elegant to me, but so far we've avoided having a local data folder.
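Option 1 boils down to "download once, then read from disk." A minimal sketch, assuming a pickled model and a cache folder at `~/hypertools_data` (the location mentioned later in this thread). `load_cached` and the injected `fetch` callable are hypothetical; `fetch` stands in for whatever actually downloads the bytes:

```python
import pickle
from pathlib import Path


def load_cached(name, fetch, data_dir=Path.home() / 'hypertools_data'):
    """Load a pickled object from a local data folder, downloading it only if missing.

    `fetch(name)` is a hypothetical callable returning the file's raw bytes
    (e.g. from a download); it is injected so the caching logic stands alone.
    """
    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)
    path = data_dir / name
    if not path.exists():
        # First use: download and store the file locally.
        path.write_bytes(fetch(name))
    with open(path, 'rb') as f:
        return pickle.load(f)
```

After the first call, later calls read from disk without downloading again.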

andrewheusser commented 6 years ago

The latest: I wrote a text2mat function, which takes a list (or list of lists) of text samples as input and converts them to matrices using a vectorizer (by setting the vectorizer parameter to count, tfidf, or a custom vectorizer) followed by a text model (by setting the text_model parameter to LDA, NMF, or a custom model). The custom model can be a prefit model instance (like the wiki model) or a class. Custom models must follow the scikit-learn transformer API (fit, transform, and fit_transform methods).
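What "follows the scikit-learn transformer API" means can be sketched with a toy custom text model. `WordProportions` is a hypothetical example class (not part of hypertools) that turns word counts into per-document word proportions. Subclassing `TransformerMixin` supplies `fit_transform`, so only `fit` and `transform` need to be written:

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin


class WordProportions(BaseEstimator, TransformerMixin):
    """Toy custom text model: word counts -> per-document word proportions."""

    def fit(self, X, y=None):
        # Nothing to learn; returning self is what the scikit-learn API expects.
        return self

    def transform(self, X):
        # Accept sparse matrices (as produced by CountVectorizer) or dense arrays.
        X = np.asarray(X.toarray() if hasattr(X, 'toarray') else X, dtype=float)
        totals = X.sum(axis=1, keepdims=True)
        totals[totals == 0] = 1.0  # avoid dividing by zero for empty documents
        return X / totals
```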

To implement this in the plot function so that users can pass text directly to hyp.plot, i will:

jeremymanning commented 6 years ago

hey @andrewheusser -- where are we with this?

andrewheusser commented 6 years ago

almost done but i still want to add the option to use the fit wiki model. i've created the model but i haven't set up a way to access it via download or load from disk if it exists

jeremymanning commented 6 years ago

:+1: got it-- thanks!

andrewheusser commented 6 years ago

@jeremymanning i think this is finally ready for your review. here is a list of the changes made on this PR: https://github.com/ContextLab/hypertools/releases/tag/untagged-86d0fbc6541a2e29d6bb. Let me know if you have questions!

jeremymanning commented 6 years ago

I tried the demo that you described above:

data = [['i like cats alot', 'cats r pretty cool', 'cats are better than dogs'],
        ['dogs rule the haus', 'dogs are my jam', 'dogs are a mans best friend']]
hyp.plot(data,'o')

I'm getting this error:

---------------------------------------------------------------------------
EOFError                                  Traceback (most recent call last)
<ipython-input-3-0b520719b960> in <module>()
      1 data = [['i like cats alot', 'cats r pretty cool', 'cats are better than dogs'],
      2         ['dogs rule the haus', 'dogs are my jam', 'dogs are a mans best friend']]
----> 3 hyp.plot(data,'o')

/usr/local/lib/python3.6/site-packages/hypertools/plot/plot.py in plot(x, fmt, marker, markers, linestyle, linestyles, color, colors, palette, group, hue, labels, legend, title, size, elev, azim, ndims, model, model_params, reduce, cluster, align, normalize, n_clusters, save_path, animate, duration, tail_duration, rotations, zoom, chemtrails, precog, bullettime, frame_rate, explore, show, transform, vectorizer, semantic, corpus, ax)
    246     # analyze the data
    247     if transform is None:
--> 248         raw = format_data(x, **text_args)
    249         xform = analyze(raw, ndims=ndims, normalize=normalize, reduce=reduce,
    250                     align=align, internal=True)

/usr/local/lib/python3.6/site-packages/hypertools/tools/format_data.py in format_data(x, vectorizer, semantic, corpus, ppca, text_align)
    120                 text_data.append(np.array(i).reshape(-1, 1))
    121         # convert text to numerical matrices
--> 122         text_data = text2mat(text_data, **text_args)
    123 
    124     # replace the text data with transformed data

/usr/local/lib/python3.6/site-packages/hypertools/_shared/helpers.py in memoizer(*args, **kwargs)
    169         key = str(args) + str(kwargs)
    170         if key not in cache:
--> 171             cache[key] = obj(*args, **kwargs)
    172         return cache[key]
    173     return memoizer

/usr/local/lib/python3.6/site-packages/hypertools/tools/text2mat.py in text2mat(data, vectorizer, semantic, corpus)
     80             semantic = 'LatentDirichletAllocation'
     81     elif semantic in ('wiki', 'nips', 'sotus',):
---> 82         semantic = load(semantic + '_model')
     83         vectorizer = None
     84         model_is_fit = True

/usr/local/lib/python3.6/site-packages/hypertools/tools/load.py in load(dataset, reduce, ndims, align, normalize, download)
    108         data = DataGeometry(**geo)
    109     elif dataset in datadict.keys():
--> 110         data = _load_data(dataset, datadict[dataset])
    111     else:
    112         raise RuntimeError('No data loaded. Please specify a .geo file or '

/usr/local/lib/python3.6/site-packages/hypertools/tools/load.py in _load_data(dataset, fileid)
    146         data = _load_from_disk(dataset)
    147     else:
--> 148         data = _load_from_disk(dataset)
    149     return data
    150 

/usr/local/lib/python3.6/site-packages/hypertools/tools/load.py in _load_from_disk(dataset)
    174         try:
    175             with open(fullpath, 'rb') as f:
--> 176                 return pickle.load(f)
    177         except ValueError as e:
    178             print(e)

EOFError: Ran out of input

andrewheusser commented 6 years ago

@jeremymanning - i modified the load function with a try statement to attempt to load in an example dataset and if that fails, redownload the dataset and load it in. I think this should fix the issue you were having above
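The retry logic described here can be sketched as follows (an illustration, not the actual hypertools `load` code). An empty or truncated pickle raises `EOFError` ("Ran out of input"), as in the traceback above, so that case triggers a fresh download and a second load attempt. `load_or_redownload` and the injected `fetch` callable are hypothetical:

```python
import pickle
from pathlib import Path


def load_or_redownload(path, fetch):
    """Try to unpickle a cached file; if it is missing or corrupt, re-fetch it and retry once.

    `fetch()` is a hypothetical callable returning the file's raw bytes (e.g. a fresh download).
    """
    path = Path(path)
    try:
        with open(path, 'rb') as f:
            return pickle.load(f)
    except (FileNotFoundError, EOFError, pickle.UnpicklingError):
        # Missing, empty, or truncated file: overwrite it with a fresh copy.
        path.write_bytes(fetch())
        with open(path, 'rb') as f:
            return pickle.load(f)
```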

andrewheusser commented 6 years ago

i think this is ready to merge now!

jeremymanning commented 6 years ago

This looks great! However, I found a couple of bugs (I think):

sotus = hyp.load('sotus')
hyp.plot(sotus, '.') #why are the dots different colors?  how is coloring determined?
hyp.plot(sotus) #nothing shows up-- but I think this should result in a line plot

jeremymanning commented 6 years ago

^ typos corrected above

andrewheusser commented 6 years ago

hmm, i think this is the expected behavior. sotus is a geo, so you can just do:

geo = hyp.load('sotus')
geo.plot()

although, the way you did it isn't wrong, because hyp.plot can handle geos. The dots are different colors because the data are split into a list of numpy arrays, where each array contains a different president's sotus, e.g. [bush1, bush2, clinton, ...]. The labels do not show up because when you pass a geo to hyp.plot, the default arguments are applied (and the default is labels=None). We could change this so that hyp.plot(geo) just calls geo.plot() internally, but then any other arguments passed in would have to be ignored, I think.

andrewheusser commented 6 years ago

ah, i didn't see the second one. that's definitely a bug haha

jeremymanning commented 6 years ago

@andrewheusser is the bug now squashed or should i hold off on further review?

andrewheusser commented 6 years ago

Not squashed! hold off and I'll tackle it after CNS

On Sat, Mar 24, 2018 at 3:21 PM Jeremy Manning wrote:

@andrewheusser is the bug now squashed or should i hold off on further review?

andrewheusser commented 6 years ago

now that i'm thinking about this...i'm wondering about our design decision to support handling of geos in hyp.plot. It doesn't seem necessary considering that all geos have the plot method already attached to them. for example:

sotus = hyp.load('sotus')
sotus.plot() # this is the intended API and works
hyp.plot(sotus) # the line plots don't show up 

before i dig into why, i wanted to see if there was a good reason to support geo as an input format for hyp.plot. One reason i can think of is that it resets the default arguments, allowing the user to create a new geo with all of the default arguments, but that's really the only difference i think..

andrewheusser commented 6 years ago

ah - i figured out why hyp.plot(sotus) doesn't work. After processing, the data is a list of 1x3 matrices, and each matrix has only one coordinate (you need at least two coordinates to draw a line), so matplotlib doesn't draw anything. For this to be drawn as a line, the text data would need to be input as a list of lists of strings instead of a list of strings.

jeremymanning commented 6 years ago

Shouldn't we support lists of strings in addition to lists of lists of strings? A list of strings is the analog of a single array or dataframe, and a list of lists of strings is like a list of arrays/dataframes

andrewheusser commented 6 years ago

we do support a list of strings (in addition to lists of lists of strings). It's just that for lists of strings, each string is treated as a document and transformed to a single point. This works fine with point plots, but line plots won't work.

this is actually not specific to text. If you do something like hyp.plot([np.random.rand(1, 10) for i in range(10)]), it will create an empty plot because of the way matplotlib handles drawing lines from single points (it doesn't draw anything).

possible solution: if arr.shape[0]==1, format as a dot instead of a line
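A minimal sketch of that proposed fix, using a hypothetical `choose_fmt` helper (the real change in hypertools may differ, e.g. in how it preserves colors given in the format string). Any dataset with a single row gets a point marker, since matplotlib needs at least two points to draw a line segment:

```python
import numpy as np


def choose_fmt(arr, fmt='-'):
    """Return a point marker for single-observation data, otherwise the requested format."""
    arr = np.atleast_2d(arr)
    if arr.shape[0] == 1:
        # A single observation can't form a line, so draw it as a dot instead.
        return 'o'
    return fmt


# One format per dataset: the 1 x 3 array gets a dot, the 5 x 3 array keeps the line.
datasets = [np.random.rand(1, 3), np.random.rand(5, 3)]
fmts = [choose_fmt(d) for d in datasets]
```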

jeremymanning commented 6 years ago

This makes sense... But we should reformat the sotu data so that it works as intended (if that's not already done). I like the "force plotting a dot if only one observation" solution.

jeremymanning commented 6 years ago

Python 3 notes (same behavior as 2.7 unless noted below):

andrewheusser commented 6 years ago

Plotting a single string produces no plot-- e.g. hyp.plot('this is a test'). Same with plotting a single list of length 1 with a string (hyp.plot(['this is a test'])). Both of these should either plot a single point, or output a warning that the use case is not supported. The issue was more a general bug with handling datasets where nrows < ndims. Now, if dimensionality reduction cannot be performed (nrows == 1 across all datasets), a warning is thrown and zeros of shape (1, ndims) are returned; elif the number of rows < ndims, a warning is thrown saying that the data will be reduced to the number of rows.
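The rule described in that reply can be sketched as a hypothetical `plan_reduction` helper (not the actual hypertools code). One assumption to flag: "number of rows" is taken here as the total across all datasets, since they are typically stacked before fitting the reduction model. A result of `None` means reduction is impossible and the caller would return zeros of shape (1, ndims):

```python
import warnings

import numpy as np


def plan_reduction(datasets, ndims=3):
    """Return how many dimensions to reduce to, or None if reduction is impossible."""
    nrows = [np.atleast_2d(x).shape[0] for x in datasets]
    if all(n == 1 for n in nrows):
        # Caller is expected to return np.zeros((1, ndims)) for each dataset.
        warnings.warn('Cannot reduce single observations; returning zeros.')
        return None
    total_rows = sum(nrows)
    if total_rows < ndims:
        # PCA-style models can't produce more components than there are rows.
        warnings.warn(f'Data will be reduced to {total_rows} dimensions (the number of rows).')
        return total_rows
    return ndims
```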

Plotting a list of length 2 (both with strings) also produces no output-- e.g. hyp.plot(['this is a test', 'is it not?']). I'm not sure what the expected behavior is, but I was thinking I'd get a line...? In general, I was thinking each string represents one document (a point), and each list represents a collection of documents (a trajectory). If the user passes a list of lists of strings, that reflects multiple collections of documents. This is now fixed so that if a list of strings is passed, a line will be plotted. If nrows < ndims (as is the case in this example), a warning will be thrown that the data dimensionality will be reduced to nrows.

By the same logic, this should produce a line (for the first document collection) and a single point (for the second document collection): hyp.plot([['this is a test', 'is it not?'], ['yes, i think it is a test']]). This (correctly) plots two trajectories as expected: hyp.plot([['this is a test', 'is it not?'], ['yes, i think it is a test', "but i don't like tests!"]]) -- so something seems off about the above examples (possibly the same issue related to plotting a single point for a single document, even if the user specifies a line). Now, if the shape of an array is 1 x something, we plot a point, even when a line is specified as the format string (which is the default).

I think this should result in each of wiki, nips, and weights being plotted in different colors (or possibly each element of those data structures being plotted in different colors): hyp.plot([wiki, nips, weights], '.'). When regenerating the text geos, I accidentally saved the nips data in the wiki geo; it's corrected now, i.e. wiki is all plotted in one color, nips in another, and weights in many colors (because it's a list of matrices). (You'll have to clear your cache if you want to test it out again.)

Future feature request: make geo objects iterable and indexable, and possibly have them be extensions of numpy arrays or dataframes. Adding an issue.

hyp.load "no data loaded" error should be updated to include new (text) datasets. Done.

describe function seems to not be working-- e.g. hyp.describe(sotus). Works now!

should we get rid of describe_pca? Yes, we had a warning in the last version that it is deprecated. Done.
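For the iterable/indexable feature request above, here is a minimal sketch of the container protocol involved. `DataCollection` is a hypothetical stand-in, not hypertools' DataGeometry: `__len__`, `__getitem__`, and `__iter__` are enough for `len(geo)`, `geo[0]`, slicing, and `for d in geo`:

```python
import numpy as np


class DataCollection:
    """Hypothetical iterable, indexable wrapper around a list of datasets."""

    def __init__(self, data):
        self.data = [np.asarray(d) for d in data]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Integer index -> one array; slice -> a new collection of the selected arrays.
        if isinstance(idx, slice):
            return DataCollection(self.data[idx])
        return self.data[idx]

    def __iter__(self):
        return iter(self.data)
```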

jeremymanning commented 6 years ago

almost everything looks good-- except hyp.describe(sotus) still isn't working for me:

/Users/jmanning/Library/Enthought/Canopy_64bit/User/lib/python2.7/site-packages/hypertools/tools/describe.py:62: UserWarning: When input data is large, this computation can take a long time.
  warnings.warn('When input data is large, this computation can take a long time.')
/Users/jmanning/Library/Enthought/Canopy_64bit/User/lib/python2.7/site-packages/seaborn/timeseries.py:183: UserWarning: The tsplot function is deprecated and will be removed or replaced (in a substantially altered version) in a future release.
  warnings.warn(msg, UserWarning)

Seems like a seaborn thing? Maybe as simple as updating the requirements list...

andrewheusser commented 6 years ago

hmmm, it works fine for me! the seaborn thing is just a warning that they will be deprecating tsplot (which the describe function uses). Are you getting an error somewhere else?

jeremymanning commented 6 years ago

Ah. Here is the actual error (previous test was w/ python 2; here is the python 3 error):

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-12-003ba24b692e> in <module>()
----> 1 hyp.describe(sotus)

/usr/local/lib/python3.6/site-packages/hypertools/tools/describe.py in describe(x, reduce, max_dims, show, format_data)
    100     if show:
    101         fig, ax = plt.subplots()
--> 102         ax = sns.tsplot(data=result['individual'], time=[i for i in range(2, max_dims+2)], err_style="unit_traces")
    103         ax.set_title('Correlation with raw data by number of components')
    104         ax.set_ylabel('Correlation')

/usr/local/lib/python3.6/site-packages/seaborn/timeseries.py in tsplot(data, time, unit, condition, value, err_style, ci, interpolate, color, estimator, n_boot, err_palette, err_kws, legend, ax, **kwargs)
    264                                  time=times,
    265                                  unit=units,
--> 266                                  cond=conds))
    267 
    268     # Set up the err_style and ci arguments for the loop below

/usr/local/lib/python3.6/site-packages/pandas/core/frame.py in __init__(self, data, index, columns, dtype, copy)
    328                                  dtype=dtype, copy=copy)
    329         elif isinstance(data, dict):
--> 330             mgr = self._init_dict(data, index, columns, dtype=dtype)
    331         elif isinstance(data, ma.MaskedArray):
    332             import numpy.ma.mrecords as mrecords

/usr/local/lib/python3.6/site-packages/pandas/core/frame.py in _init_dict(self, data, index, columns, dtype)
    459             arrays = [data[k] for k in keys]
    460 
--> 461         return _arrays_to_mgr(arrays, data_names, index, columns, dtype=dtype)
    462 
    463     def _init_ndarray(self, values, index, columns, dtype=None, copy=False):

/usr/local/lib/python3.6/site-packages/pandas/core/frame.py in _arrays_to_mgr(arrays, arr_names, index, columns, dtype)
   6161     # figure out the index, if necessary
   6162     if index is None:
-> 6163         index = extract_index(arrays)
   6164     else:
   6165         index = _ensure_index(index)

/usr/local/lib/python3.6/site-packages/pandas/core/frame.py in extract_index(data)
   6209             lengths = list(set(raw_lengths))
   6210             if len(lengths) > 1:
-> 6211                 raise ValueError('arrays must all be same length')
   6212 
   6213             if have_dicts:

ValueError: arrays must all be same length

andrewheusser commented 6 years ago

ah, can you try clearing your cache? you probably have an old version of the sotus dataset. the cache is in /Users/yourname/hypertools_data/

jeremymanning commented 6 years ago

that worked! let's either clear the cache on installation of this version or add something to the documentation warning users of this issue. merging...
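One way to "clear the cache on installation of this version" is to record which version filled the cache and wipe it on a mismatch. A minimal sketch with a hypothetical `clear_stale_cache` helper; the `VERSION` marker file is an assumption, not an existing hypertools convention. Be aware that it deletes the whole cache folder when the versions differ:

```python
import shutil
from pathlib import Path


def clear_stale_cache(current_version, data_dir=Path.home() / 'hypertools_data'):
    """Wipe the dataset cache if it was written by a different package version.

    A 'VERSION' marker file records which version filled the cache. Returns True
    if the cache folder was reset, False if it was already current.
    """
    data_dir = Path(data_dir)
    marker = data_dir / 'VERSION'
    if marker.exists() and marker.read_text().strip() == current_version:
        return False
    if data_dir.exists():
        # Stale or unversioned cache: remove everything so datasets are re-downloaded.
        shutil.rmtree(data_dir)
    data_dir.mkdir(parents=True)
    marker.write_text(current_version)
    return True
```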