概要
GRU(Gated Recurrent Unit)は、LSTMのゲート機構を簡略化した系列モデルである。2014年にChoらが提案した。LSTMが3つのゲートとセル状態・隠れ状態の2ベクトルを持つのに対し、GRUはリセットゲートと更新ゲートの2つだけで長期記憶を実現し、隠れ状態のみで動作する。パラメータ数が約25%少なく、同等の性能を多くのタスクで達成する。
直感・モチベーション
LSTMは勾配消失を解決したが、3つのゲートと2つの状態ベクトルによりパラメータ数が多く計算コストも高い。「もっと少ないゲートで同じことができないか」という動機でGRUが設計された。
GRUのアイデアは2点に集約される。
- 更新ゲート: LSTMの忘却ゲートと入力ゲートを1つに統合する。「どれだけ過去を引き継ぐか」と「どれだけ新情報を取り込むか」をトレードオフの関係に固定する
- リセットゲート: 候補隠れ状態を計算するとき、過去の隠れ状態をどれだけ参照するかを制御する。短期的なパターンを捉えるときは過去をリセットする
セル状態を廃止して隠れ状態だけにしても、更新ゲートが を直接加算路として保持するため、勾配消失への耐性はLSTMと同程度になる。
数学的定式化
時刻 の入力 、前ステップの隠れ状態 に対して:
ここで はシグモイド関数、 は要素積。
導出を見る
LSTMとの対応関係:
| LSTM | GRU | 役割 |
|---|---|---|
| 忘却ゲート | 過去をどれだけ保持するか | |
| 入力ゲート | 新情報をどれだけ取り込むか | |
| セル状態 | (隠れ状態 に統合) | 長期記憶の経路 |
更新ゲートの式 は、過去と新情報の線形補間になっている。 なら (完全に過去を引き継ぐ)、 なら (完全に新情報に置き換える)。
勾配の流れ:
第1項 は を経由しない直接経路であり、LSTMのセル状態に相当する勾配の高速道路になる。
重要な性質・注意点
- パラメータ数: LSTMの約3/4(ゲートが3→2、セル状態なし)
- LSTMとの性能比較: タスク依存。長い系列や複雑な記憶が必要なタスクではLSTMが有利なことがある
- 更新ゲートの解釈: のユニットは長期記憶を担い、 のユニットは短期的な入力に素早く反応する
まとめ
- LSTMの3ゲート+セル状態を、2ゲート+隠れ状態のみに簡略化
- 更新ゲートが忘却と入力を統合し、リセットゲートが過去参照量を制御
- パラメータ数が少なく学習が速い。多くのタスクでLSTMと同等の性能
- 設計がシンプルなため解釈・実装がしやすい
実装例
pythonimport math import random def sigmoid(x: float) -> float: return 1.0 / (1.0 + math.exp(-x)) class GRUCell: def __init__(self, input_dim: int, hidden_dim: int): self.hidden_dim = hidden_dim d = input_dim + hidden_dim scale = 0.01 # 更新ゲート・リセットゲート・候補それぞれの重み self.Wz = [[random.gauss(0, scale) for _ in range(d)] for _ in range(hidden_dim)] self.Wr = [[random.gauss(0, scale) for _ in range(d)] for _ in range(hidden_dim)] self.Wh = [[random.gauss(0, scale) for _ in range(d)] for _ in range(hidden_dim)] self.bz = [0.0] * hidden_dim self.br = [0.0] * hidden_dim self.bh = [0.0] * hidden_dim def matvec(self, W, x): return [sum(W[i][j] * x[j] for j in range(len(x))) for i in range(len(W))] def forward(self, x: list[float], h: list[float]) -> list[float]: xh = x + h H = self.hidden_dim z = [sigmoid(self.matvec(self.Wz, xh)[i] + self.bz[i]) for i in range(H)] r = [sigmoid(self.matvec(self.Wr, xh)[i] + self.br[i]) for i in range(H)] # リセットゲートで過去の隠れ状態を絞ってから候補を計算 rh = [r[i] * h[i] for i in range(H)] xrh = x + rh h_tilde = [math.tanh(self.matvec(self.Wh, xrh)[i] + self.bh[i]) for i in range(H)] h_new = [(1 - z[i]) * h[i] + z[i] * h_tilde[i] for i in range(H)] return h_new class GRU: def __init__(self, input_dim: int, hidden_dim: int): self.cell = GRUCell(input_dim, hidden_dim) self.hidden_dim = hidden_dim def forward(self, xs: list[list[float]]) -> list[list[float]]: h = [0.0] * self.hidden_dim hs = [] for x in xs: h = self.cell.forward(x, h) hs.append(h[:]) return hs # 使用例 model = GRU(input_dim=4, hidden_dim=8) xs = [[0.1, 0.2, 0.3, 0.4]] * 10 hs = model.forward(xs) print("最終隠れ状態:", hs[-1][:3], "...")
関連記事
前提知識
- RNN — ゲート機構が解決する勾配消失問題の源
- LSTM — GRUが簡略化の対象としたアーキテクチャ
派生技術
- Seq2Seq — GRUをエンコーダ・デコーダに使った系列変換モデル
- Attention — GRU/LSTMの長距離依存の限界を補う機構
応用事例
- 機械翻訳 — LSTMより軽量な代替としてSeq2Seqに使用
- 音声認識 — 音響特徴量の系列をラベル系列に変換
用語解説
| 用語 | 説明 |
|---|---|
| 更新ゲート | 過去の隠れ状態と候補隠れ状態の混合比率を決めるゲート |
| リセットゲート | 候補隠れ状態の計算で過去の隠れ状態をどれだけ参照するかを決めるゲート |
| 候補隠れ状態 | リセットゲートで絞った過去と現在の入力から計算される暫定的な隠れ状態 |
| 線形補間 | 2つのベクトルを の形で重み付き平均する操作 |
関連リンク
この記事への参照
- 比較対象RNN