When Recurrent Models Don’t Need To Be Recurrent
04 Oct 2018

Introduction

The paper explores “if a well behaved RNN can be replaced by a feedforward network of comparable size without loss in performance.”

“Well behaved” is defined in terms of the control-theoretic notion of stability. Roughly, this requires that the gradients do not explode over time.

The paper shows that under the stability assumption, feedforward networks can approximate RNNs for both training and inference. The results are empirically validated as well.
Problem Setting

Consider a general, nonlinear dynamical system given by a differentiable state-transition map Φ_{w}. The hidden state evolves as h_{t} = Φ_{w}(h_{t-1}, x_{t}).

Assumptions:
 Φ is smooth in w and h.
 h_{0} = 0
 Φ_{w}(0, 0) = 0 (can be ensured by translation)
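This setup can be sketched in a few lines of NumPy, taking Φ_{w}(h, x) = tanh(W h + U x) as an illustrative transition map (a vanilla RNN cell; the names W and U are my own, not from the paper):

```python
import numpy as np

def phi(W, U, h_prev, x_t):
    """Illustrative state-transition map: Phi_w(h, x) = tanh(W h + U x).

    Since tanh(0) = 0, the assumption Phi_w(0, 0) = 0 holds by construction.
    """
    return np.tanh(W @ h_prev + U @ x_t)

def run_recurrent(W, U, xs):
    """Iterate the map over an input sequence xs, starting from h_0 = 0."""
    h = np.zeros(W.shape[0])
    for x_t in xs:
        h = phi(W, U, h, x_t)
    return h
```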

Stable models are those where Φ is contractive, i.e., ‖Φ_{w}(h, x) − Φ_{w}(h’, x)‖ ≤ Λ ‖h − h’‖ for some Λ < 1.

For example, in an RNN, stability requires that ‖w‖ < (L_{p})^{-1}, where L_{p} is the Lipschitz constant of the pointwise nonlinearity used.
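For tanh, L_{p} = 1, so the condition reduces to ‖W‖_2 < 1. A small sketch (function names mine) that checks this condition and, if needed, rescales a weight matrix back into the stable region:

```python
import numpy as np

def is_stable(W, lipschitz=1.0):
    # Stability for an RNN: spectral norm of W strictly below 1 / L_p
    return np.linalg.norm(W, 2) < 1.0 / lipschitz

def project_stable(W, lipschitz=1.0, margin=0.99):
    """Rescale W so that ||W||_2 <= margin / L_p (a simple projection heuristic)."""
    limit = margin / lipschitz
    norm = np.linalg.norm(W, 2)
    return W if norm <= limit else W * (limit / norm)
```

The rescaling is one simple way to enforce the constraint during training; it shrinks all singular values uniformly rather than clipping only the largest one.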

The feedforward approximation is a truncated model: it uses only a finite context of length k (the last k inputs).
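Under the same illustrative tanh transition map as above, the truncated counterpart simply restarts from h = 0 and consumes only the last k inputs, making it a depth-k feedforward computation:

```python
import numpy as np

def hidden_state(W, U, xs):
    # Full recurrent model: h_t = tanh(W h_{t-1} + U x_t), starting from h_0 = 0
    h = np.zeros(W.shape[0])
    for x_t in xs:
        h = np.tanh(W @ h + U @ x_t)
    return h

def truncated_hidden_state(W, U, xs, k):
    """Feedforward approximation with context length k: drop all but the last k inputs."""
    return hidden_state(W, U, xs[-k:])
```

When k is at least the sequence length, the two computations coincide exactly; the interesting regime is k much smaller than the sequence length.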

A nonparametric function f maps the output of the recurrent model to the prediction. If f is desired to be parametric, its parameters can be folded into the recurrent model.
Theoretical Results

For a Λ-contractive system, it can be proved that, for sufficiently large k (and under additional Lipschitz assumptions), the difference in prediction between the recurrent and the truncated model is negligible.
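This decay can be checked numerically. In the sketch below (an illustrative setup, not the paper's experiments), a tanh RNN is made Λ-contractive by rescaling W to spectral norm 0.9, and the gap between the full hidden state and its k-truncated approximation shrinks roughly like Λ^k:

```python
import numpy as np

rng = np.random.default_rng(0)
n, T = 8, 200
W = rng.standard_normal((n, n))
W *= 0.9 / np.linalg.norm(W, 2)   # enforce ||W||_2 = 0.9 < 1, i.e. Lambda = 0.9
U = rng.standard_normal((n, n))
xs = rng.standard_normal((T, n))

def hidden(seq):
    # h_t = tanh(W h_{t-1} + U x_t), starting from h_0 = 0
    h = np.zeros(n)
    for x in seq:
        h = np.tanh(W @ h + U @ x)
    return h

h_full = hidden(xs)
# Truncation error for increasing context lengths k
errors = {k: np.linalg.norm(h_full - hidden(xs[-k:])) for k in (5, 20, 80)}
```

The contraction argument gives the bound ‖h_{T} − h_{T}^{k}‖ ≤ Λ^k ‖h_{T-k}‖, since both models see identical inputs over the last k steps and each step shrinks the gap by a factor of at most Λ.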

If the recurrent model and the truncated feedforward network are initialized at the same point and trained on the same input for N steps, then for a suitable k the weights of the two models remain very close in Euclidean distance. It can be shown that this small weight difference does not lead to large gradient differences during subsequent update steps.

This can be roughly interpreted as: if gradient descent can train a stable recurrent network, it can also train the corresponding feedforward model, and vice versa.

The stability condition is important: without it, truncated models can be poor approximations (even for large values of k), and it becomes difficult to show that gradient descent converges to a stationary point.