numenta / htmresearch

Experimental algorithms. Unsupported.
GNU Affero General Public License v3.0

RES-885: Refactor linearsdr and cnnsdr classes #973

Closed: lscheinkman closed this 5 years ago

lscheinkman commented 5 years ago

@subutai Please review. Here is an example of how to use the new API:

```python
import torch.nn as nn
import htmresearch.frameworks.pytorch.modules as htm

# ...

# Equivalent to LinearSDR class
self.l1 = htm.SparseWeights(nn.Linear(28*28, 500), 0.4)  # 0.4: weight sparsity

self.kw = htm.KWinners(n=500, k=50, kInferenceFactor=1.5,
                       boostStrength=1.0, boostStrengthFactor=0.9)
```

Then, in the forward method, just call self.kw(self.l1(x)).
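The following hypothetical sketch shows what self.kw(self.l1(x)) computes on a single input vector. It is written in plain Python and uses neither htmresearch nor PyTorch. Several parts of it are assumptions based on the general k-winners-with-boosting scheme, not the library's actual code:

- the helper names (make_sparse_linear, linear, k_winners, update_duty_cycles);
- reading the 0.4 argument as the fraction of non-zero weights per output unit;
- the exponential boost formula;
- the duty-cycle update.

```python
import math
import random

# Hypothetical stand-ins for htm.SparseWeights and htm.KWinners.
# Plain Python lists, a single input vector, no autograd.


def make_sparse_linear(n_in, n_out, weight_sparsity, seed=0):
    """Random weights where each output unit keeps only
    round(weight_sparsity * n_in) non-zero input weights (assumed meaning)."""
    rng = random.Random(seed)
    n_keep = int(round(weight_sparsity * n_in))
    weights = []
    for _ in range(n_out):
        row = [rng.uniform(-1.0, 1.0) for _ in range(n_in)]
        # Zero a random subset of inputs so only n_keep stay connected.
        for i in rng.sample(range(n_in), n_in - n_keep):
            row[i] = 0.0
        weights.append(row)
    return weights


def linear(weights, x):
    """Matrix-vector product without bias."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def k_winners(x, duty_cycles, k, boost_strength):
    """Keep the k units with the largest boosted activation; zero the rest.
    Units whose duty cycle is below the target density k/n get boosted."""
    n = len(x)
    target_density = k / n
    boosted = [xi * math.exp((target_density - d) * boost_strength)
               for xi, d in zip(x, duty_cycles)]
    winners = sorted(range(n), key=lambda i: boosted[i], reverse=True)[:k]
    out = [0.0] * n
    for i in winners:
        out[i] = x[i]  # winners pass through their unboosted value
    return out


def update_duty_cycles(duty_cycles, out, period=1000):
    """Moving average of how often each unit produces a non-zero output."""
    return [d * (period - 1) / period + (1.0 if v != 0.0 else 0.0) / period
            for d, v in zip(duty_cycles, out)]


# Usage, mirroring self.kw(self.l1(x)) on a tiny 8 -> 10 layer with k = 3.
W = make_sparse_linear(n_in=8, n_out=10, weight_sparsity=0.4)
duty = [0.0] * 10
rng = random.Random(1)
x = [rng.random() for _ in range(8)]
y = k_winners(linear(W, x), duty, k=3, boost_strength=1.0)
duty = update_duty_cycles(duty, y)
```

Two further behaviors are also assumptions about how the real module uses its parameters. First, boost_strength would be multiplied by boostStrengthFactor (0.9 in the example) after each epoch, so boosting fades as training progresses. Second, kInferenceFactor=1.5 would raise k at inference time, for example from 50 to 75 winners.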

The same API is also available in 2D for CNNs.
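For the 2D (CNN) case, here is a similar hypothetical sketch in plain Python. The input is a list of C channels, each an H×W grid. The sketch makes two assumptions about how a 2D k-winners layer could work; neither describes the htmresearch class:

- boosting is tracked per channel, with one duty cycle per channel;
- the k winners are chosen over all C·H·W units at once.

```python
import math


def k_winners_2d(x, duty_cycles, k, boost_strength):
    """x: C channels, each an H x W list of lists.
    duty_cycles: one value per channel (assumed per-channel boosting).
    Returns a copy of x with all but the k strongest units set to 0.0."""
    n = sum(len(ch) * len(ch[0]) for ch in x)  # total units C*H*W
    target_density = k / n
    factors = [math.exp((target_density - d) * boost_strength)
               for d in duty_cycles]
    # Flatten to (boosted value, channel, row, col) so the top-k is global.
    units = [(v * factors[c], c, r, q)
             for c, ch in enumerate(x)
             for r, row in enumerate(ch)
             for q, v in enumerate(row)]
    units.sort(key=lambda u: u[0], reverse=True)
    out = [[[0.0 for _ in row] for row in ch] for ch in x]
    for _, c, r, q in units[:k]:
        out[c][r][q] = x[c][r][q]  # keep the unboosted value
    return out


# Two 2x2 channels; in the second call channel 1 is marked as over-active.
x = [[[0.1, 0.2], [0.3, 0.4]],
     [[0.5, 0.0], [0.0, 0.6]]]
no_boost = k_winners_2d(x, [0.0, 0.0], k=2, boost_strength=1.0)
boosted = k_winners_2d(x, [0.0, 1.0], k=1, boost_strength=1.0)
```

When both channels have equal duty cycles, the two largest activations (0.6 and 0.5, both in channel 1) win. Once channel 1 is marked as over-active, boosting moves the single winner to channel 0, even though its raw activation (0.4) is smaller than 0.6.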

lscheinkman commented 5 years ago

@subutai This refactor may break the experiments in your branches. Should I wait for you to merge your branches before I merge this PR?

subutai commented 5 years ago

No, go ahead and merge; my branch needs your changes to make progress anyway. I can update my branch afterwards.