How to Train Deep NMF Model in PyTorch


Recently I updated the implementation of PyTorch-NMF so that it can scale to large and complex NMF models. In this blog post I will briefly explain how this was done, thanks to the automatic differentiation of PyTorch.

Multiplicative Update Rules with Beta Divergence

Multiplicative Update is a classic update method that has been widely used in many NMF applications. Its form is easy to derive, it guarantees a monotonic decrease of the loss value, and it keeps the parameters non-negative after every update.

Below are the multiplicative update rules when using the Beta-Divergence as our criterion, for a factorization $V \approx WH$ (Févotte & Idier, 2011):

$$H \leftarrow H \odot \frac{W^\top \left((WH)^{\beta-2} \odot V\right)}{W^\top (WH)^{\beta-1}}, \qquad W \leftarrow W \odot \frac{\left((WH)^{\beta-2} \odot V\right) H^\top}{(WH)^{\beta-1} H^\top}$$

where $\odot$, the division, and the powers are element-wise.
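For concreteness, here is a minimal hand-written sketch of these closed-form updates in PyTorch (a plain reference loop, not the PyTorch-NMF code); the rest of the post shows how to obtain the same update weights automatically.

```python
import torch

# Hand-written multiplicative updates for V ~= W @ H with the Beta-Divergence.
# A minimal reference sketch, not the PyTorch-NMF implementation.
beta, eps = 1, 1e-8            # beta = 1 corresponds to the KL divergence
V = torch.rand(64, 100)
W = torch.rand(64, 16)
H = torch.rand(16, 100)

for _ in range(200):
    P = W @ H
    H *= (W.T @ (P ** (beta - 2) * V)) / (W.T @ P ** (beta - 1) + eps)
    P = W @ H
    W *= ((P ** (beta - 2) * V) @ H.T) / (P ** (beta - 1) @ H.T + eps)
```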


Decoupling the Derivative

The update weights are actually derived from the derivative of the chosen criterion with respect to the parameters ($H$ and $W$). Due to a property of the Beta-Divergence, this derivative can be expressed as the difference of two non-negative functions:

$$\nabla_H \mathcal{L} = \nabla_H^+ \mathcal{L} - \nabla_H^- \mathcal{L}, \qquad \nabla_H^+ \mathcal{L} \ge 0, \; \nabla_H^- \mathcal{L} \ge 0.$$

Then the multiplicative update can simply be written as:

$$H \leftarrow H \odot \frac{\nabla_H^- \mathcal{L}}{\nabla_H^+ \mathcal{L}}.$$

Following the chain rule, we can also decouple the derivative with respect to the parameter (take $H$ for example):

$$\nabla_H \mathcal{L} = \left(\frac{\partial (WH)}{\partial H}\right)^\top \nabla_{WH} \mathcal{L} = W^\top \nabla_{WH}^+ \mathcal{L} - W^\top \nabla_{WH}^- \mathcal{L}.$$

The derivative of $WH$ with respect to $H$ is $W^\top$, which is always non-negative, so the ability to decouple the gradient into two non-negative functions actually comes from the Beta-Divergence itself.

The same steps can be applied to $W$ as well.
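A quick sanity check of the chain-rule step, assuming a single matrix-multiplication layer: back-propagating a matrix $g$ through $P = WH$ deposits $W^\top g$ in `H.grad`.

```python
import torch

# Back-propagating g through P = W @ H yields H.grad == W.T @ g.
W = torch.rand(5, 3)
H = torch.rand(3, 7, requires_grad=True)
P = W @ H
g = torch.rand(5, 7)
P.backward(g)
print(torch.allclose(H.grad, W.T @ g))  # True
```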

Derivative of Beta-Divergence

The form of the Beta-Divergence is:

$$D_\beta(V \,\|\, P) = \sum_{ij} d_\beta(V_{ij} \,|\, P_{ij}), \qquad d_\beta(x \,|\, y) = \begin{cases} \dfrac{x^\beta + (\beta - 1)\, y^\beta - \beta\, x\, y^{\beta - 1}}{\beta (\beta - 1)} & \beta \neq 0, 1 \\[1ex] x \log \dfrac{x}{y} - x + y & \beta = 1 \\[1ex] \dfrac{x}{y} - \log \dfrac{x}{y} - 1 & \beta = 0 \end{cases}$$

where $P = WH$, and its derivative with respect to $P$ is:

$$\nabla_P D_\beta(V \,\|\, P) = P^{\beta - 1} - V \odot P^{\beta - 2}$$

with element-wise powers.

It is indeed the difference of two non-negative terms: $\nabla_P^+ = P^{\beta-1}$ and $\nabla_P^- = V \odot P^{\beta-2}$.
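As a small illustration, we can let autograd differentiate the Beta-Divergence directly and compare the result with the closed form above (a sketch with an arbitrary $\beta = 1.5$):

```python
import torch

# Autograd gradient of the Beta-Divergence w.r.t. P matches P**(beta-1) - V * P**(beta-2).
beta = 1.5
V = torch.rand(4, 6) + 0.1
P = (torch.rand(4, 6) + 0.1).requires_grad_()

D = ((V ** beta + (beta - 1) * P ** beta - beta * V * P ** (beta - 1))
     / (beta * (beta - 1))).sum()
D.backward()

expected = P.detach() ** (beta - 1) - V * P.detach() ** (beta - 2)
print(torch.allclose(P.grad, expected))  # True
```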

Deriving the Weights via Back-propagation

The 2-Backward-Pass Algorithm

Now we can see that the two non-negative functions with respect to the parameter can be viewed as the two non-negative functions with respect to the NMF output, each multiplied by the derivative of the NMF output with respect to the parameter. The latter can be evaluated by PyTorch's automatic differentiation, so we only need to compute the former. Once we have them, we simply back-propagate through the computational graph twice to obtain the multiplicative update weights.

Steps

  1. Calculate the NMF output P.
  2. Given P and the target V, derive the two non-negative components (pos and neg) of the derivative with respect to P.
  3. Derive one non-negative component of the derivative with respect to the parameter that needs to be updated, via back-propagation (in PyTorch, P.backward(pos, retain_graph=True)).
  4. Derive the remaining non-negative component of the derivative via back-propagation (in PyTorch, P.backward(neg)).
  5. Derive the multiplicative update weights by dividing the result of step 4 by the result of step 3; a minimal sketch of these steps in code follows this list.
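Below is a minimal sketch of the five steps for a single NMF layer, updating only $H$; the function name and signature are illustrative, not the actual PyTorch-NMF API. Note that the gradients from the two backward passes must be kept apart, so `H.grad` is cleared between the passes.

```python
import torch

# A minimal sketch of the 2-backward-pass multiplicative update for V ~= W @ H,
# updating H only.  Illustrative code, not the actual PyTorch-NMF implementation.
def mu_update_H(W, H, V, beta=2, eps=1e-8):
    P = W @ H                                   # step 1: NMF output

    pos = P.detach() ** (beta - 1)              # step 2: dD/dP = pos - neg
    neg = V * P.detach() ** (beta - 2)

    H.grad = None                               # step 3: positive part of dD/dH
    P.backward(pos, retain_graph=True)
    grad_pos = H.grad.clone()

    H.grad = None                               # step 4: negative part of dD/dH
    P.backward(neg)
    grad_neg = H.grad

    with torch.no_grad():                       # step 5: multiply by neg / pos
        H *= grad_neg / (grad_pos + eps)
    H.grad = None


V = torch.rand(64, 100)
W = torch.rand(64, 16)
H = torch.rand(16, 100, requires_grad=True)
for _ in range(100):
    mu_update_H(W, H, V)
```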

What’s the Benefit of this Approach?

Well, because most of the update weights can now be computed by automatic differentiation, we can support the following features more easily, without writing closed-form solutions:

  • Advanced matrix/tensor operations: Some NMF variants (like Deconvolutional NMF) use convolution instead of a simple matrix multiplication to calculate the output; in PyTorch, convolution is supported natively and is fully differentiable.
  • Deeper NMF structures: Recently, some research has tried to learn much higher-level features by stacking multiple NMF layers, probably inspired by the rapid progress of Deep Learning in the last decade. But due to the non-negativity constraints, deriving a closed-form update solution is non-trivial. With PyTorch-NMF, as long as the gradients are all non-negative along the back-propagation path in the computational graph, we can put an arbitrary number of NMF layers in our model, or even more complex structures of operations, and train them jointly (both cases are sketched below).
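For illustration, here are hypothetical forward passes for the two variants above; the shapes and layer structure are made up for the example, and the same 2-backward-pass update applies because every Jacobian along the path is non-negative.

```python
import torch
import torch.nn.functional as F

# (a) Deconvolutional NMF: the output is a 1-D convolution instead of a matmul.
W = torch.rand(64, 16, 5, requires_grad=True)    # (out_channels, in_channels, kernel)
H = torch.rand(1, 16, 100, requires_grad=True)   # (batch, in_channels, time)
P_conv = F.conv1d(H, W, padding=2)               # (1, 64, 100), still non-negative

# (b) Deep NMF: stack two factorizations, V ~= W1 @ (W2 @ H2).
W1 = torch.rand(64, 32, requires_grad=True)
W2 = torch.rand(32, 8, requires_grad=True)
H2 = torch.rand(8, 100, requires_grad=True)
P_deep = W1 @ (W2 @ H2)
```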

Conclusion

In this post I showed how PyTorch-NMF applies multiplicative update rules to much more advanced (or deeper) NMF models, and I hope this project can benefit researchers from various fields.

(This project is still in early development; if you are interested in supporting it, please contact me.)

References

  • Févotte, Cédric, and Jérôme Idier. “Algorithms for nonnegative matrix factorization with the β-divergence.” Neural computation 23.9 (2011): 2421-2456.
  • PyTorch-NMF, source code
  • PyTorch-NMF, documentation