Double Backpropagationについて | moskomule log

Double Backpropagationについて

はじめに

PyTorch v0.2では”Higher order gradients” (double backpropagation)がサポートされました.Chainerもv3においてこれがサポートされます.今回Chainer Meetupの資料を読んで雰囲気が分かったのでまとめました.

筆者は長くdouble backpropagationという名称から

\[\mathrm{loss}\longrightarrow \frac{\partial^2 \mathrm{loss}}{\partial x_i \partial x_j} \]

と思い込んでいました.そう思っているのでdocumentを読んでもいまいちよく分からない.ところが上に挙げた資料では,そうではなくて

\[\mathrm{loss}=g(f(x), \frac{\partial f(x)}{\partial x})\]

のような場合にも計算が出来る,ということなのだということが説明されていて救われました.

PyTorchの例

これで以上,でもよいのですが,PyTorchでの例を.

$x=1, y=x^3, z=y^2+\frac{dy}{dx}$をとして,$\frac{dz}{dx}|_{x=1}$を求めます.

>>> x = Variable(torch.Tensor([1]), requires_grad=True)
>>> y = x ** 3
>>> grad_y, = autograd.grad(y, x, create_graph=True)
>>> (grad_y + y ** 2).backward()
>>> x.grad
Variable containing:
 12
[torch.FloatTensor of size 1]

実際,$z=x^6+3x^2$ですので$\frac{dz}{dx}|_{x=1}=(6x+6x^5)|_{x=1}=12$となります.

$\frac{dy}{dx}$つまりgrad_yautograd.gradを用いてautograd.grad(y, x, create_graph=True)によってつくっています.create_graphによって計算グラフ中にこの微分のためのグラフをつくっています.autograd.grad(outputs, inputs, *)inputs, outputsはシークエンスで与えられて,返り値もタプルなので上記のようにgrad_y,とする必要があります.

使いみち

PyTorch v0.2のリリースには

you can compute Hessian-Vector products, penalize the norm of the gradients of your model, implement Unrolled GANs and Improved WGANs, etc

とあります.

Hessian-Vector productsは函数$f(x)$のHessianと任意のベクトル$v$の積$\frac{\partial^2 f}{\partial x_i\partial x_j}v$です.

\[ \frac{\partial^2 f}{\partial x_i\partial x_j}v=\frac{\partial}{\partial x}(\frac{\partial f}{\partial x} v) \]

を用います.例えば$f(x)=3x_0^2+4x_0x_1+x_1^2$とすればHessianは$\begin{bmatrix}6 & 4 \cr 4 & 2\end{bmatrix}$ですから$v=\begin{bmatrix}1 \cr 1\end{bmatrix}$とすればHessian-Vector productは$\begin{bmatrix}10 \cr 6\end{bmatrix}$となるはずです.実際,

>>> v = Variable(torch.Tensor([1, 1]))
>>> x = Variable(torch.Tensor([0.1, 0.1]), requires_grad=True)
>>> f = 3 * x[0] ** 2 + 4 * x[0] * x[1] + x[1] **2
>>> grad_f, = autograd.grad(f, x, create_graph=True)
>>> z = grad_f @ v
>>> z.backward()
>>> x.grad
Variable containing:
 10
  6
[torch.FloatTensor of size 2]
comments powered by Disqus