state-spaces / s4

Structured state space sequence models
Apache License 2.0

SequenceModel step function _setup_state missing? #87

Open FloMru opened 1 year ago

FloMru commented 1 year ago

I am building an encoder/decoder architecture in which both the encoder and the decoder are S4 models. Since this is for a generative task, I want to use the step function defined in the SequenceModel class for inference.

When I call the function I get:

```
  File "state-spaces/statespaces/models/sequence/model.py", line 129, in step
    x, state = layer.step(x, state=prev_state, **kwargs)
  File "state-spaces/src/models/sequence/block.py", line 117, in step
    y, state = self.layer.step(y, state, **kwargs)
  File "/state-spaces/src/models/sequence/ss/s4.py", line 253, in step
    y, next_state = self.kernel.step(u, state) # (B C H)
  File "/state-spaces/src/models/sequence/ss/kernel.py", line 1127, in step
    y, state = self.kernel.step(u, state, **kwargs)
  File "/state-spaces/src/models/sequence/ss/kernel.py", line 862, in step
    next_state = contract("h n, b h n -> b h n", self.dA, state) \
  File "/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1269, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'SSKernelDiag' object has no attribute 'dA'. Did you mean: '_A'?
```

Could the problem be that the _setup_step function for the kernel is never called? If so, what would be a practical way to call it while using the SequenceModel class?

Thanks, Flo

albertfgu commented 1 year ago

You have to call a setup function first. See the generation script for an example: https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/generate.py#L190

Let me know if that works.
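The failure and the fix pattern can be reproduced in a small, self-contained sketch. This is an illustration under stated assumptions, not the repo's code. ToyDiagKernel, ToyModel and setup_step_everywhere are hypothetical names. The kernel uses a real, negative diagonal A with zero-order-hold discretization. The sketch assumes that each kernel computes its discrete dA only when its setup method runs; that would explain why step() raises the AttributeError from the traceback. The linked generate.py shows the actual calls the repo uses. In a real PyTorch model, the standard torch.nn.Module.modules() method would replace the toy modules() generator.

```python
import math


class ToyDiagKernel:
    """Hypothetical stand-in for a diagonal SSM kernel; not the repo's SSKernelDiag."""

    def __init__(self, A, B, C, dt):
        # Continuous-time diagonal parameters (A entries assumed nonzero and negative).
        self._A, self.B, self.C, self.dt = A, B, C, dt

    def setup_step(self):
        # Discretize once with zero-order hold: dA = exp(dt*A), dB = (dA - 1) / A * B.
        self.dA = [math.exp(self.dt * a) for a in self._A]
        self.dB = [(da - 1.0) / a * b for da, a, b in zip(self.dA, self._A, self.B)]

    def step(self, u, state):
        # Reads self.dA, so it raises AttributeError until setup_step() has run.
        next_state = [da * s + db * u for da, db, s in zip(self.dA, self.dB, state)]
        y = sum(c * s for c, s in zip(self.C, next_state))
        return y, next_state


class ToyModel:
    """Hypothetical container standing in for SequenceModel."""

    def __init__(self, layers):
        self.layers = layers

    def modules(self):
        # Mimics torch.nn.Module.modules(): the model itself, then its submodules.
        yield self
        yield from self.layers

    def step(self, u, states):
        # Feed the input through each layer's step(), collecting the new per-layer states.
        new_states = []
        for layer, s in zip(self.layers, states):
            u, s = layer.step(u, s)
            new_states.append(s)
        return u, new_states


def setup_step_everywhere(model):
    # Call setup_step() on every submodule that defines it, before any step() call.
    for m in model.modules():
        if hasattr(m, "setup_step"):
            m.setup_step()


model = ToyModel([ToyDiagKernel(A=[-1.0, -2.0], B=[1.0, 1.0], C=[1.0, 1.0], dt=0.1)])
states = [[0.0, 0.0]]  # one zero state per layer
try:
    model.step(1.0, states)
except AttributeError as err:
    print("before setup:", err)
setup_step_everywhere(model)
y, states = model.step(1.0, states)
print("after setup:", y)
```

Walking all submodules means the model code does not need to know which layers hold kernels. The same setup is also needed for the decoder in an encoder/decoder stack.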

Also note that unrolling the model step by step during training may be very slow, depending on the model and data.
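A toy comparison helps explain why stepping is slow for training. S4-style layers are usually trained in convolutional mode, where the whole sequence is processed at once. Step mode needs one sequential update per time step. The sketch below assumes a real, negative diagonal A with zero-order-hold discretization, and the function names are hypothetical. It checks that both modes give the same outputs for the same discretized system. The convolution here is a naive O(L^2) sum for clarity, whereas practical implementations typically use an FFT.

```python
import math


def discretize(A, B, dt):
    # Zero-order hold for a diagonal (here: real, negative) state matrix.
    dA = [math.exp(dt * a) for a in A]
    dB = [(da - 1.0) / a * b for da, a, b in zip(dA, A, B)]
    return dA, dB


def run_recurrent(u, dA, dB, C):
    # Step mode: x_t = dA * x_{t-1} + dB * u_t, y_t = C . x_t, strictly sequential in t.
    state = [0.0] * len(dA)
    ys = []
    for u_t in u:
        state = [da * s + db * u_t for da, db, s in zip(dA, dB, state)]
        ys.append(sum(c * s for c, s in zip(C, state)))
    return ys


def run_convolutional(u, dA, dB, C):
    # Convolution mode: K[k] = sum_n C[n] * dA[n]**k * dB[n], then causal y_t = sum_k K[k] * u[t-k].
    L = len(u)
    K = [sum(c * da ** k * db for c, da, db in zip(C, dA, dB)) for k in range(L)]
    return [sum(K[k] * u[t - k] for k in range(t + 1)) for t in range(L)]


dA, dB = discretize([-1.0, -2.0], [1.0, 1.0], 0.1)
u = [1.0, 0.5, -0.25, 2.0]
y_rec = run_recurrent(u, dA, dB, [1.0, 1.0])
y_conv = run_convolutional(u, dA, dB, [1.0, 1.0])
print(max(abs(a - b) for a, b in zip(y_rec, y_conv)))
```

Every output of the convolutional form can be computed independently once K is known. The recurrent loop cannot start step t before step t-1 finishes, which is why it is the natural fit for generation but usually the slower choice for training.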