DDPG by gymnasium 6日目

Contents

前回までの動き:

  1. agentがobsを受けてchoose_actionでactor(NN)をforwardしactionを出力する。
  2. actionを受けてenv.stepし結果としてnext_state,reward,doneを得る。
  3. agent.rememberで結果obs,action,rewerd,next_state,int(done)を保存する。
  4. rememberで64データ集まったらagent.learnで学習が始まる。
学習:
  1. sample_bufferで64データをランダムに取り出す。
  2. dtype=T.float32に変換する。
  3. target_actorへnext_states 64データを入力してtarget_actions 64データを得る。
ここまで作成しました。

今回

このtarget_actionsをnext_statesと共にtarget_criticへ入力するところからやっていきます。
この部分こそ連続値対応できるDDPGの核心部分なので十分に理解する必要があります。
引き続き 学習learn()メソド内での処理です。

やっていこう

AgentDDPG.learn()メソド内の

target_actions = self.target_actor.forward(next_states)

の直下に
target_critic_values
 = self.target_critic.forward(next_states, target_actions)
を入れます。
算出したてのtarget_actionsとnext_statesの2つを入力として、target_critic_valuesを出力します。

ちょっと説明をいれると、DDPGはTD法なのでTDターゲットとしてr + γ*V(w)[s_t+1]を考えます。target_critic_valuesはこれのことです。

この価値関数Vの部分をtarget_criticNNで表現します。

# 21.ターゲットクリティックネットワークインスタンスtarget_criticを作成します。AgentDDPG.__init__()内に定義します。
target_actorの引数 学習率alphaをcritic用にbetaへ変更しています
# 22.CriticNNクラスを作成します。
これでtarget_critic_values が返ってくる

#23.ベースラインとして機能するクリティックネットワーク(価値関数V(w)[s_t]ネットワーク)に 現在の状態observationsと行動actionsを入力してcritic_valueを算出する。
AgentDDPG.learn()メソドに戻ってさっきほどの
target_critic_values
 = self.target_critic.forward(next_states, target_actions)
の直下に
 critic_values
   = self.critic.forward(observations, actions)
を入れる。criticインスタンスはまだ作成していないので、AgentDDPG.__init__()に追加する
# 24. AgentDDPG.__init__()にクリティックインスタンス生成を追加する
これで4つのNNを導入することができた。

# 25.target_criticからTDターゲット(= r + γ*V(w)[s_t+1])を算出する。

AgentDDPG.learn()メソドに戻って、

まとめとこれまでのスクリプト

actorが行動して集めたデータから、経験再生を使って「次の状態」からtarget_actorが「次の行動」を出力し、「次の行動」と「次の状態」からtarget_criticが「次の状態価値」出力し、TDターゲットを算出しました。注意すべきはここで言う「次の行動」とはあくまでtarget_actorが生み出した「架空の行動」です。

一方で、criticは経験再生を使って「現在の状態」と「そのとき取った行動」から「現在の状態価値」を算出ししました。注意すべきは、こちらの「そのとき取った行動」とは実際にactorが行動して経験再生バッファに保存されたデータです。

また、この「現在の状態価値」をベースラインと呼びます。次回以降。「TDターゲット-ベースライン」の演算が出てくるので注目です。

ではまた次回