dattalab / pyhsmm-library-models

library models built on top of pyhsmm

speed benchmark #7

Closed: mattjj closed this issue 11 years ago

mattjj commented 11 years ago

Enough of the code has changed that it's time to benchmark again!

mattjj commented 11 years ago

Gotta pay attention to how well it speeds up in parallel, too.

mattjj commented 11 years ago

I edited parallel-test.py to use two sequences of length 40k (2D data only) and got this profile:

Timer unit: 1e-06 s

File: /Users/mattjj/Dropbox/work/pyhsmm-library-models/pyhsmm/models.py
Function: resample_model_parallel at line 199
Total time: 32.4534 s

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   199                                               @profile
   200                                               def resample_model_parallel(self,numtoresample='all',temp=None):
   201        25          175      7.0      0.0          import parallel
   202        25           25      1.0      0.0          if numtoresample == 'all':
   203        25           27      1.1      0.0              numtoresample = len(self.states_list)
   204                                                   elif numtoresample == 'engines':
   205                                                       numtoresample = parallel.get_num_engines()
   206                                           
   207                                                   ### resample parameters locally
   208        25      3748094 149923.8     11.5          self.resample_obs_distns()
   209        25      1729157  69166.3      5.3          self.resample_trans_distn()
   210        25         2018     80.7      0.0          self.resample_init_state_distn()
   211                                           
   212                                                   ### choose which sequences to resample
   213        25          512     20.5      0.0          states_to_resample = random.sample(self.states_list,numtoresample)
   214        75           67      0.9      0.0          states_to_hold_out = [s for s in self.states_list if s not in states_to_resample]
   215                                           
   216                                                   ### resample states in parallel
   217        25           26      1.0      0.0          self.states_list = states_to_resample
   218        25     26973287 1078931.5     83.1          self.resample_states_parallel(temp=temp)
   219                                           
   220                                                   ### add back the held-out states
   221                                                   # NOTE: this might shuffle the order of states_list from the order in
   222                                                   # which data were added if numtoresample != 'all'
   223        25           46      1.8      0.0          self.states_list.extend(states_to_hold_out)

File: /Users/mattjj/Dropbox/work/pyhsmm-library-models/pyhsmm/models.py
Function: resample_states_parallel at line 225
Total time: 26.9727 s

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   225                                               @profile
   226                                               def resample_states_parallel(self,temp=None):
   227        25          135      5.4      0.0          import parallel
   228        25           18      0.7      0.0          states = self.states_list
   229        25           13      0.5      0.0          self.states_list = [] # removed because we push the global model
   230        25           27      1.1      0.0          raw_tuples = parallel.call_data_fn(
   231        25           23      0.9      0.0                  fn=self._state_sampler,
   232        75          250      3.3      0.0                  datas=[self._get_parallel_data(s) for s in states],
   233        25          150      6.0      0.0                  kwargss=self._get_parallel_kwargss(states),
   234        25     26944902 1077796.1     99.9                  engine_globals=dict(global_model=self,temp=temp),
   235                                                           )
   236        25        27141   1085.6      0.1          self._add_back_states_from_parallel(raw_tuples)

That looks like it's doing what we want, since most of the time (83%) is being spent in state resampling. For comparison, here's the same thing with only one engine:

Timer unit: 1e-06 s

File: /Users/mattjj/Dropbox/work/pyhsmm-library-models/pyhsmm/models.py
Function: resample_model_parallel at line 199
Total time: 49.6725 s

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   199                                               @profile
   200                                               def resample_model_parallel(self,numtoresample='all',temp=None):
   201        25          164      6.6      0.0          import parallel
   202        25           26      1.0      0.0          if numtoresample == 'all':
   203        25           27      1.1      0.0              numtoresample = len(self.states_list)
   204                                                   elif numtoresample == 'engines':
   205                                                       numtoresample = parallel.get_num_engines()
   206                                           
   207                                                   ### resample parameters locally
   208        25      3448687 137947.5      6.9          self.resample_obs_distns()
   209        25      1363170  54526.8      2.7          self.resample_trans_distn()
   210        25         1758     70.3      0.0          self.resample_init_state_distn()
   211                                           
   212                                                   ### choose which sequences to resample
   213        25          455     18.2      0.0          states_to_resample = random.sample(self.states_list,numtoresample)
   214        75           63      0.8      0.0          states_to_hold_out = [s for s in self.states_list if s not in states_to_resample]
   215                                           
   216                                                   ### resample states in parallel
   217        25           21      0.8      0.0          self.states_list = states_to_resample
   218        25     44858063 1794322.5     90.3          self.resample_states_parallel(temp=temp)
   219                                           
   220                                                   ### add back the held-out states
   221                                                   # NOTE: this might shuffle the order of states_list from the order in
   222                                                   # which data were added if numtoresample != 'all'
   223        25           40      1.6      0.0          self.states_list.extend(states_to_hold_out)

File: /Users/mattjj/Dropbox/work/pyhsmm-library-models/pyhsmm/models.py
Function: resample_states_parallel at line 225
Total time: 44.8575 s

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   225                                               @profile
   226                                               def resample_states_parallel(self,temp=None):
   227        25          123      4.9      0.0          import parallel
   228        25           19      0.8      0.0          states = self.states_list
   229        25           20      0.8      0.0          self.states_list = [] # removed because we push the global model
   230        25           19      0.8      0.0          raw_tuples = parallel.call_data_fn(
   231        25           21      0.8      0.0                  fn=self._state_sampler,
   232        75          227      3.0      0.0                  datas=[self._get_parallel_data(s) for s in states],
   233        25          137      5.5      0.0                  kwargss=self._get_parallel_kwargss(states),
   234        25     44834277 1793371.1     99.9                  engine_globals=dict(global_model=self,temp=temp),
   235                                                           )
   236        25        22670    906.8      0.1          self._add_back_states_from_parallel(raw_tuples)

So with two engines on my laptop, the time spent in resample_states_parallel goes down from 45 sec to 27 sec, about a 1.7x speedup.
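One way to read these two profiles is as speedup (serial time divided by parallel time) and parallel efficiency (speedup divided by the number of engines). A minimal sketch using the resample_states_parallel totals from the one- and two-engine profiles; the helper names are illustrative and not part of pyhsmm:

```python
def speedup(t_serial, t_parallel):
    """Ratio of serial to parallel wall-clock time."""
    return t_serial / t_parallel

def efficiency(t_serial, t_parallel, n_engines):
    """Fraction of ideal linear speedup achieved on n_engines."""
    return speedup(t_serial, t_parallel) / n_engines

# Total time in resample_states_parallel (seconds), from the profiles above.
one_engine = 44.8575
two_engines = 26.9727

print(f"speedup:    {speedup(one_engine, two_engines):.2f}x")      # 1.66x
print(f"efficiency: {efficiency(one_engine, two_engines, 2):.0%}")  # 83%
```

By this measure the two-engine run gets roughly 83% of ideal linear scaling on the state-resampling step.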

If we ignore the parallel code completely and just use add_data() and resample_model(), this is the equivalent profile:

Timer unit: 1e-06 s

File: /Users/mattjj/Dropbox/work/pyhsmm-library-models/pyhsmm/models.py
Function: resample_model at line 151
Total time: 51.6126 s

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   151                                               @profile
   152                                               def resample_model(self,temp=None):
   153        25      4167688 166707.5      8.1          self.resample_obs_distns()
   154        25      1109072  44362.9      2.1          self.resample_trans_distn()
   155        25         1834     73.4      0.0          self.resample_init_state_distn()
   156        25     46334027 1853361.1     89.8          self.resample_states(temp=temp)

File: /Users/mattjj/Dropbox/work/pyhsmm-library-models/pyhsmm/models.py
Function: resample_states at line 171
Total time: 46.3335 s

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   171                                               @profile
   172                                               def resample_states(self,temp=None):
   173        75          233      3.1      0.0          for s in self.states_list:
   174        50     46333265 926665.3    100.0              s.resample(temp=temp)

Since resample_states takes 46.3 seconds here, compared with 44.9 seconds for resample_states_parallel on one engine, running the parallel code with one engine doesn't seem to incur any noticeable overhead compared to the pure serial code.

Alex didn't see any speedup at all on real data, though, so the next task is to profile with real data.
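The overhead comparison can be made explicit as a relative difference between the two state-resampling totals. A small sketch using the numbers from the serial and one-engine profiles (variable names are illustrative):

```python
# Total time (seconds) for the state-resampling step, from the profiles above.
serial_resample_states = 46.3335      # resample_states, no parallel code at all
one_engine_resample_states = 44.8575  # resample_states_parallel, 1 engine

# Relative overhead of the parallel path; a negative value means the
# one-engine parallel path was not slower than the pure serial path.
overhead = (one_engine_resample_states - serial_resample_states) / serial_resample_states
print(f"relative overhead: {overhead:+.1%}")  # -3.2%
```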

alexbw commented 11 years ago

The exact data I'm using can be loaded like this:

import numpy as np
f = np.load("/home/alexbw/Data/C57-1-data-7-28-2013-fortesting.npz")
data = f['data']
means = f['means']
sigmas = f['sigmas']
labels = f['labels']

alexbw commented 11 years ago

Re-testing real data benchmarks without a full disk... freed up a few hundred GBs...

Serial speed is ~14.0 sec / iteration. 2 clients gives ~12.4 sec / iteration. 12 clients gives ~13.0 sec / iteration.

My gut trusts Matt's results, not mine here, but I can't say specifically why. Something doesn't feel right when I run these tests. I'm running stuff in other kernels on my current notebook, but I'll start the whole thing fresh in a moment.

But if the idea was that the full disk was interfering with the results, that now seems false.

mattjj commented 11 years ago

Well, at least it's not giving a slowdown on 12 clients anymore...

What file/notebook is that in?

mattjj commented 11 years ago

Below is a profile on real data (106k frames, split into 2 chunks) from running real_data_speed.py with @profile decorators added. The salient parts are:

  1. 57.8% of the time is spent resampling the observation distributions, which isn't parallelized at all yet but will be soon (#12).
  2. Running 'import parallel' every time is totally wasteful because it re-creates the view each time! That should run only once, globally. Free speedup! (#13)
  3. Pushing the model (dv.push) takes about 5.6% of the total time (14.0 s of 249.2 s). That can probably be cut down with better pickling, but it's not a big win.
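Item 1 also caps what adding engines can buy. Amdahl's law (a standard model, not something from this codebase) says that if a fraction s of each iteration runs serially, the speedup on n engines is at most 1 / (s + (1 - s)/n). A small sketch, assuming s is the share of the locally resampled parameters in the profile below (resample_obs_distns 57.8% plus resample_trans_distn 0.7%) and optimistically treating everything else as perfectly parallel:

```python
def amdahl_bound(serial_fraction, n_engines):
    """Upper bound on speedup when only (1 - serial_fraction) of the work parallelizes."""
    return 1.0 / (serial_fraction + (1.0 - serial_fraction) / n_engines)

# Serial share of each iteration: parameters resampled locally, from the profile below.
s = 0.578 + 0.007

for n in (2, 6, 12):
    print(f"{n:2d} engines: at most {amdahl_bound(s, n):.2f}x")
print(f"limit with unlimited engines: {1.0 / s:.2f}x")
```

This bounds the speedup at about 1.26x on 2 engines and about 1.71x no matter how many engines are added. In this 2-chunk setup, state resampling also can't use more than 2 engines at once, so the practical ceiling for 6 or 12 engines is lower still.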

After fixing those things, re-profile (including with more engines).
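For item 2, one common Python pattern for making expensive setup run only once per process is a lazily initialized, cached accessor. A minimal sketch, assuming a hypothetical make_view() factory standing in for whatever parallel.py does to build its client and view:

```python
import functools

def make_view():
    """Hypothetical stand-in for the expensive client/view construction."""
    make_view.calls += 1
    return object()

make_view.calls = 0

@functools.lru_cache(maxsize=None)
def get_view():
    # The first call builds the view; every later call returns the cached object.
    return make_view()

v1 = get_view()
v2 = get_view()
print(v1 is v2, make_view.calls)  # True 1
```

Python already caches imported modules in sys.modules, so a pattern like this matters mainly when the setup lives inside a function that gets called repeatedly rather than running once at module import.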

Timer unit: 1e-06 s

File: /home/mattjj/work/pyhsmm-library-models/pyhsmm/models.py
Function: resample_model_parallel at line 199
Total time: 249.233 s

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   199                                               @profile
   200                                               def resample_model_parallel(self,numtoresample='all',temp=None):
   201        25          262     10.5      0.0          import parallel
   202        25           70      2.8      0.0          if numtoresample == 'all':
   203        25           71      2.8      0.0              numtoresample = len(self.states_list)
   204                                                   elif numtoresample == 'engines':
   205                                                       numtoresample = parallel.get_num_engines()
   206                                           
   207                                                   ### resample parameters locally
   208        25    144018874 5760755.0     57.8          self.resample_obs_distns()
   209        25      1844494  73779.8      0.7          self.resample_trans_distn()
   210        25         2990    119.6      0.0          self.resample_init_state_distn()
   211                                           
   212                                                   ### choose which sequences to resample
   213        25          670     26.8      0.0          states_to_resample = random.sample(self.states_list,numtoresample)
   214        75          166      2.2      0.0          states_to_hold_out = [s for s in self.states_list if s not in states_to_resample]
   215                                           
   216                                                   ### resample states in parallel
   217        25           54      2.2      0.0          self.states_list = states_to_resample
   218        25    103365184 4134607.4     41.5          self.resample_states_parallel(temp=temp)
   219                                           
   220                                                   ### add back the held-out states
   221                                                   # NOTE: this might shuffle the order of states_list from the order in
   222                                                   # which data were added if numtoresample != 'all'
   223        25           92      3.7      0.0          self.states_list.extend(states_to_hold_out)

File: /home/mattjj/work/pyhsmm-library-models/pyhsmm/models.py
Function: resample_states_parallel at line 225
Total time: 103.364 s

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   225                                               @profile
   226                                               def resample_states_parallel(self,temp=None):
   227        25          188      7.5      0.0          import parallel
   228        25           50      2.0      0.0          states = self.states_list
   229        25           48      1.9      0.0          self.states_list = [] # removed because we push the global model
   230        25           52      2.1      0.0          raw_tuples = parallel.call_data_fn(
   231        25           55      2.2      0.0                  fn=self._state_sampler,
   232        75          342      4.6      0.0                  datas=[self._get_parallel_data(s) for s in states],
   233        25          227      9.1      0.0                  kwargss=self._get_parallel_kwargss(states),
   234        25    103317948 4132717.9    100.0                  engine_globals=dict(global_model=self,temp=temp),
   235                                                           )
   236        25        45287   1811.5      0.0          self._add_back_states_from_parallel(raw_tuples)

File: /home/mattjj/work/pyhsmm-library-models/pyhsmm/parallel.py
Function: call_data_fn at line 76
Total time: 103.315 s

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    76                                           @profile
    77                                           def call_data_fn(fn,datas,kwargss=None,engine_globals=None):
    78        25          937     37.5      0.0      assert all(parallel_hash(data) in _data_to_id_dict for data in datas)
    79                                               # assert all(data_exists_on_engine(data) for data in datas) # assumes ndarray
    80                                           
    81        25           84      3.4      0.0      if engine_globals is not None:
    82        25     14016158 560646.3     13.6          dv.push(engine_globals,block=False)
    83                                           
    84        25         1267     50.7      0.0      data_ids_to_resample = set(_data_to_id(data) for data in datas)
    85                                           
    86        25          104      4.2      0.0      if kwargss is None:
    87                                                   results = dv.apply_sync(_call_data_fn,fn,data_ids_to_resample)
    88                                               else:
    89        25          944     37.8      0.0          kwargs_for_each_data = {_data_to_id(data):kwargs for data,kwargs in zip(datas,kwargss)}
    90        25     89224522 3568980.9     86.4          results = dv.apply_sync(_call_data_fn,fn,data_ids_to_resample,kwargs_for_each_data)
    91        25        69706   2788.2      0.1      c.purge_results('all')
    92        25          569     22.8      0.0      results = filter(lambda r: len(r) > 0, results)
    93                                           
    94        25          146      5.8      0.0      assert set(data_ids_to_resample) == \
    95        25          309     12.4      0.0              set(data_id for result in results for data_id,_ in result), \
    96                                                       "some data did not exist on any engine"
    97                                           
    98       125          607      4.9      0.0      return [(_id_to_data(data_id),outs) for result in results for data_id, outs in result]
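Item 3 and the dv.push line in call_data_fn (about 14 s of its 103 s) suggest checking how big the pushed model is once pickled and how long serialization takes. A rough standard-library sketch, using a hypothetical ToyModel with made-up array sizes in place of the real model; comparing pickle protocols is one cheap experiment, though whether it helps the real push depends on how the parallel layer serializes objects:

```python
import pickle
import time

import numpy as np

class ToyModel:
    """Hypothetical stand-in for the model object pushed to the engines."""
    def __init__(self, n_states=100, dim=10):
        rng = np.random.default_rng(0)
        self.means = rng.normal(size=(n_states, dim))
        self.trans_matrix = rng.dirichlet(np.ones(n_states), size=n_states)
        self.states_list = []  # emptied before pushing, as in resample_states_parallel

def measure(obj, protocol):
    """Return (pickled size in bytes, seconds spent pickling)."""
    start = time.perf_counter()
    payload = pickle.dumps(obj, protocol=protocol)
    return len(payload), time.perf_counter() - start

model = ToyModel()
for protocol in (0, pickle.HIGHEST_PROTOCOL):
    size, seconds = measure(model, protocol)
    print(f"protocol {protocol}: {size} bytes in {seconds * 1e3:.2f} ms")
```

Protocol 0 is the old text-based format; the binary protocols store raw array bytes without escaping, so they are typically smaller and faster for numpy-heavy objects.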
mattjj commented 11 years ago

New benchmarks! Now with parallel observation distribution sampling.

On 6 sequences of length 40k each (240k frames total), the time per sample went from 16.52 sec in serial to 3.91 sec in parallel, about a 4.2x speedup. (The data is 2D, so it has a smaller memory footprint, and hence better cache efficiency, than real data.) Observation resampling is now only about 10% of the time, and the bottleneck is in state sequence resampling, which is probably where it should be.

Very promising!

One engine (serial):

Timer unit: 1e-06 s

File: /home/mattjj/work/pyhsmm-library-models/library_models.py
Function: resample_states_parallel at line 317
Total time: 371.667 s

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   317                                               @profile
   318                                               def resample_states_parallel(self,temp=None):
   319        25          181      7.2      0.0          import pyhsmm.parallel as parallel
   320        25           51      2.0      0.0          states = self.states_list
   321        25           49      2.0      0.0          self.states_list = [] # removed because we push the global model
   322        25           49      2.0      0.0          raw = parallel.map_on_each(
   323        25           52      2.1      0.0                  self._state_sampler,
   324       175          538      3.1      0.0                  [s.precomputed_likelihoods for s in states],
   325        25          334     13.4      0.0                  kwargss=self._get_parallel_kwargss(states),
   326        25    371663978 14866559.1    100.0                  engine_globals=dict(global_model=self,temp=temp), # TODO compactify
   327                                                           )
   328        25          109      4.4      0.0          self.states_list = states
   329       175          477      2.7      0.0          for s, stateseq in zip(self.states_list,raw):
   330       150         1138      7.6      0.0              s.stateseq = stateseq

File: /home/mattjj/work/pyhsmm-library-models/pyhsmm/models.py
Function: resample_model_parallel at line 192
Total time: 411.634 s

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   192                                               @profile
   193                                               def resample_model_parallel(self,numtoresample='all',temp=None):
   194        25           79      3.2      0.0          if numtoresample == 'all':
   195        25           67      2.7      0.0              numtoresample = len(self.states_list)
   196                                                   elif numtoresample == 'engines':
   197                                                       import parallel
   198                                                       numtoresample = min(parallel.get_num_engines(),len(self.states_list))
   199                                           
   200                                                   ### resample parameters locally
   201        25     33879901 1355196.0      8.2          self.resample_obs_distns_parallel()
   202        25      6078881 243155.2      1.5          self.resample_trans_distn()
   203        25         4727    189.1      0.0          self.resample_init_state_distn()
   204                                           
   205                                                   ### choose which sequences to resample
   206        25         1093     43.7      0.0          states_to_resample = random.sample(self.states_list,numtoresample)
   207       175          370      2.1      0.0          states_to_hold_out = [s for s in self.states_list if s not in states_to_resample]
   208                                           
   209                                                   ### resample states in parallel
   210        25           57      2.3      0.0          self.states_list = states_to_resample
   211        25    371668961 14866758.4     90.3          self.resample_states_parallel(temp=temp)
   212                                           
   213                                                   ### add back the held-out states
   214                                                   # NOTE: this might shuffle the order of states_list from the order in
   215                                                   # which data were added if numtoresample != 'all'
   216        25           80      3.2      0.0          self.states_list.extend(states_to_hold_out)

6 engines:

Timer unit: 1e-06 s

File: /home/mattjj/work/pyhsmm-library-models/library_models.py
Function: resample_states_parallel at line 317
Total time: 81.1456 s

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   317                                               @profile
   318                                               def resample_states_parallel(self,temp=None):
   319        25          181      7.2      0.0          import pyhsmm.parallel as parallel
   320        25           51      2.0      0.0          states = self.states_list
   321        25           49      2.0      0.0          self.states_list = [] # removed because we push the global model
   322        25           48      1.9      0.0          raw = parallel.map_on_each(
   323        25           54      2.2      0.0                  self._state_sampler,
   324       175          555      3.2      0.0                  [s.precomputed_likelihoods for s in states],
   325        25          336     13.4      0.0                  kwargss=self._get_parallel_kwargss(states),
   326        25     81142649 3245706.0    100.0                  engine_globals=dict(global_model=self,temp=temp), # TODO compactify
   327                                                           )
   328        25           82      3.3      0.0          self.states_list = states
   329       175          473      2.7      0.0          for s, stateseq in zip(self.states_list,raw):
   330       150         1129      7.5      0.0              s.stateseq = stateseq

File: /home/mattjj/work/pyhsmm-library-models/pyhsmm/models.py
Function: resample_model_parallel at line 192
Total time: 96.36 s

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   192                                               @profile
   193                                               def resample_model_parallel(self,numtoresample='all',temp=None):
   194        25           75      3.0      0.0          if numtoresample == 'all':
   195        25           70      2.8      0.0              numtoresample = len(self.states_list)
   196                                                   elif numtoresample == 'engines':
   197                                                       import parallel
   198                                                       numtoresample = min(parallel.get_num_engines(),len(self.states_list))
   199                                           
   200                                                   ### resample parameters locally
   201        25      9734719 389388.8     10.1          self.resample_obs_distns_parallel()
   202        25      5471350 218854.0      5.7          self.resample_trans_distn()
   203        25         4704    188.2      0.0          self.resample_init_state_distn()
   204                                           
   205                                                   ### choose which sequences to resample
   206        25         1106     44.2      0.0          states_to_resample = random.sample(self.states_list,numtoresample)
   207       175          371      2.1      0.0          states_to_hold_out = [s for s in self.states_list if s not in states_to_resample]
   208                                           
   209                                                   ### resample states in parallel
   210        25           51      2.0      0.0          self.states_list = states_to_resample
   211        25     81147498 3245899.9     84.2          self.resample_states_parallel(temp=temp)
   212                                           
   213                                                   ### add back the held-out states
   214                                                   # NOTE: this might shuffle the order of states_list from the order in
   215                                                   # which data were added if numtoresample != 'all'
   216        25           72      2.9      0.0          self.states_list.extend(states_to_hold_out)

File: /home/mattjj/work/pyhsmm-library-models/pyhsmm/parallel.py
Function: map_on_each at line 82
Total time: 81.1395 s

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    82                                           @profile
    83                                           def map_on_each(fn,added_datas,kwargss=None,engine_globals=None):
    84        25          122      4.9      0.0      @engine_global_namespace
    85                                               def _call(f,data_id,**kwargs):
    86                                                   return f(my_data[data_id],**kwargs)
    87                                           
    88        25           51      2.0      0.0      if engine_globals is not None:
    89        25      2452880  98115.2      3.0          dv.push(engine_globals,block=True)
    90                                           
    91        25           77      3.1      0.0      if kwargss is None:
    92                                                   kwargss = [{} for data in added_datas] # no communication overhead
    93                                           
    94       175         2186     12.5      0.0      indata = [(phash(data),data,kwargs) for data,kwargs in zip(added_datas,kwargss)]
    95        25           49      2.0      0.0      ars = [c[data_residency[data_id]].apply_async(_call,fn,data_id,**kwargs)
    96       175       289140   1652.2      0.4                      for data_id, data, kwargs in indata]
    97        25     78317053 3132682.1     96.5      dv.wait(ars)
    98       175        10018     57.2      0.0      results = [ar.get() for ar in ars]
    99                                           
   100        25        67858   2714.3      0.1      c.purge_results('all')
   101                                           
   102        25           78      3.1      0.0      return results

File: /home/mattjj/work/pyhsmm-library-models/pyhsmm/parallel.py
Function: call_with_all at line 107
Total time: 9.16861 s

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   107                                           @profile
   108                                           def call_with_all(fn,broadcasted_datas,kwargss,engine_globals=None):
   109                                               # one call for each element of kwargss
   110        25          199      8.0      0.0      @engine_global_namespace
   111                                               def _call(f,data_ids,kwargs):
   112                                                   return f([my_data[data_id] for data_id in data_ids],**kwargs)
   113                                           
   114        25           60      2.4      0.0      if engine_globals is not None:
   115         1       781993 781993.0      8.5          dv.push(engine_globals,block=True)
   116                                           
   117        25           76      3.0      0.0      results = lbv.map_sync(
   118        25           47      1.9      0.0              _call,
   119        25          141      5.6      0.0              [fn]*len(kwargss),
   120       175         2391     13.7      0.0              [[phash(data) for data in broadcasted_datas]]*len(kwargss),
   121        25      8303557 332142.3     90.6              kwargss)
   122                                           
   123        25        80047   3201.9      0.9      c.purge_results('all')
   124                                           
   125        25          103      4.1      0.0      return results
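The map_on_each profile shows the dispatch pattern: each sequence goes via apply_async to the engine that already holds its data (looked up through data_residency), and the caller then blocks in dv.wait, which accounts for 96.5% of that function's time. A rough single-machine analogue, using concurrent.futures thread pools as hypothetical stand-ins for engines and a toy work function; this illustrates the structure only and is not the IPython parallel API:

```python
from concurrent.futures import ThreadPoolExecutor, wait

# Hypothetical stand-ins: each "engine" is a one-worker pool with its own data store.
N_ENGINES = 3
engines = [ThreadPoolExecutor(max_workers=1) for _ in range(N_ENGINES)]
engine_data = [{} for _ in range(N_ENGINES)]
data_residency = {}  # data_id -> index of the engine that holds the data

def add_data(data_id, data, engine_idx):
    engine_data[engine_idx][data_id] = data
    data_residency[data_id] = engine_idx

def _call(fn, engine_idx, data_id, kwargs):
    # Runs "on the engine": only data_id is passed, the data is looked up locally.
    return fn(engine_data[engine_idx][data_id], **kwargs)

def map_on_each(fn, data_ids, kwargss=None):
    if kwargss is None:
        kwargss = [{} for _ in data_ids]
    futures = [engines[data_residency[d]].submit(_call, fn, data_residency[d], d, kw)
               for d, kw in zip(data_ids, kwargss)]
    wait(futures)  # block until every dispatched call has finished
    return [f.result() for f in futures]  # results come back in input order

for i in range(6):
    add_data(f"seq{i}", list(range(i + 1)), engine_idx=i % N_ENGINES)

results = map_on_each(sum, [f"seq{i}" for i in range(6)])
print(results)  # [0, 1, 3, 6, 10, 15]

for e in engines:
    e.shutdown()
```

Because the wall-clock time of the wait is set by the slowest engine, balancing how much data each engine holds matters as much as the number of engines.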
mattjj commented 11 years ago

I ran it on real data without profiling hooks, just to measure the iteration time. I doubled the ~106k frames with np.tile(data,(2,1)) to see how it would do on ~212k frames, split that data into 6 even pieces with np.array_split, then ran on 6 engines and compared against running in serial with no parallel machinery at all (still split into 6 pieces, though that shouldn't matter).
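The data preparation described above amounts to something like the following sketch; the random array is a hypothetical stand-in for the real ~106k-frame data, and its feature dimension is made up:

```python
import numpy as np

# Hypothetical stand-in for the real data: ~106k frames, 10 features each.
rng = np.random.default_rng(0)
data = rng.normal(size=(106_000, 10))

# Stack two copies along the time axis: ~212k frames.
doubled = np.tile(data, (2, 1))

# Split into 6 nearly equal contiguous chunks, one per engine.
chunks = np.array_split(doubled, 6)

print(doubled.shape, [len(c) for c in chunks])
# (212000, 10) [35334, 35334, 35333, 35333, 35333, 35333]
```

Unlike np.split, np.array_split doesn't require the length to divide evenly; the chunk lengths differ by at most one.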

This was all done with the file real_data_speed.py at 1e99261 on jefferson.

6 engines: 6.54 sec / iteration
serial: 24.59 sec / iteration

mattjj commented 11 years ago

We're not speeding things up right now, so I'll close this issue!