深圳幻海软件技术有限公司 欢迎您!

使用Pytorch实现强化学习——DQN算法

2023-04-19

目录一、强化学习的主要构成二、基于python的强化学习框架三、gym四、DQN算法1.经验回放2.目标网络五、使用pytorch实现DQN算法1.replaymemory2.神经网络部分3.Agent4.模型训练函数5.训练模型一、强化学习的主要构成    

目录

一、强化学习的主要构成

二、基于python的强化学习框架

三、gym

四、DQN算法

1.经验回放

2.目标网络

五、使用pytorch实现DQN算法

1.replay memory

2.神经网络部分

3.Agent

4.模型训练函数

5.训练模型


一、强化学习的主要构成

        强化学习主要由两部分组成:智能体(agent)和环境(env)。在强化学习过程中,智能体与环境一直在交互。智能体在环境里面获取某个状态后,它会利用该状态输出一个动作(action)。然后这个动作会在环境之中被执行,环境会根据智能体采取的动作,输出下一个状态以及当前这个动作带来的奖励。智能体的目的就是尽可能多地从环境中获取奖励

二、基于python的强化学习框架

        基于python的强化学习框架有很多种,具体可以见这个博主的博客:(7条消息) 【强化学习/gym】(二)一些强化学习的框架或代码_o0o_-_的博客-CSDN博客_可解释性的强化学习框架代码        本次我使用到的框架是pytorch,因为DQN算法的实现包含了部分的神经网络,这部分对我来说使用pytorch会更顺手,所以就选择了这个。

三、gym

       gym 定义了一套接口,用于描述强化学习中的环境这一概念,同时在其官方库中,包含了一些已实现的环境。

四、DQN算法

        传统的强化学习算法使用的是Q表格存储状态价值函数或者动作价值函数,但是实际应用时,问题在的环境可能有很多种状态,甚至数不清,所以这种情况下使用离散的Q表格存储价值函数会非常不合理,所以DQN(Deep Q-learning)算法,使用神经网络拟合动作价值函数

        通常DQN算法只能处理动作离散,状态连续的情况,使用神经网络拟合出动作价值函数, 然后针对动作价值函数,选择出当状态state固定的Q值最大的动作a。

DQN算法有两个特点:

1.经验回放

        每一次的样本都放到样本池中,所以可以多次反复的使用一个样本,重复利用。训练时一次随机抽取多个数据样本来进行训练。

2.目标网络

        DQN算法的更新目标时让逼近, 但是如果两个Q使用一个网络计算,那么Q的目标值也在不断改变, 容易造成神经网络训练的不稳定。DQN使用目标网络,训练时目标值Q使用目标网络来计算,目标网络的参数定时和训练网络的参数同步。

五、使用pytorch实现DQN算法

  1. import time
  2. import random
  3. import torch
  4. from torch import nn
  5. from torch import optim
  6. import gym
  7. import numpy as np
  8. import matplotlib.pyplot as plt
  9. from collections import deque, namedtuple # 队列类型
  10. from tqdm import tqdm # 绘制进度条用
  11. device = torch. Device("cuda" if torch.cuda.is_available() else "cpu")
  12. Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))

1.replay memory

  1. class ReplayMemory(object):
  2. def __init__(self, memory_size):
  3. self.memory = deque([], maxlen=memory_size)
  4. def sample(self, batch_size):
  5. batch_data = random.sample(self.memory, batch_size)
  6. state, action, reward, next_state, done = zip(*batch_data)
  7. return state, action, reward, next_state, done
  8. def push(self, *args):
  9. # *args: 把传进来的所有参数都打包起来生成元组形式
  10. # self.push(1, 2, 3, 4, 5)
  11. # args = (1, 2, 3, 4, 5)
  12. self.memory.append(Transition(*args))
  13. def __len__(self):
  14. return len(self.memory)

2.神经网络部分

  1. class Qnet(nn.Module):
  2. def __init__(self, n_observations, n_actions):
  3. super(Qnet, self).__init__()
  4. self.model = nn.Sequential(
  5. nn.Linear(n_observations, 128),
  6. nn.ReLU(),
  7. nn.Linear(128, n_actions)
  8. )
  9. def forward(self, state):
  10. return self.model(state)

3.Agent

  1. class Agent(object):
  2. def __init__(self, observation_dim, action_dim, gamma, lr, epsilon, target_update):
  3. self.action_dim = action_dim
  4. self.q_net = Qnet(observation_dim, action_dim).to(device)
  5. self.target_q_net = Qnet(observation_dim, action_dim).to(device)
  6. self.gamma = gamma
  7. self.lr = lr
  8. self.epsilon = epsilon
  9. self.target_update = target_update
  10. self.count = 0
  11. self.optimizer = optim.Adam(params=self.q_net.parameters(), lr=lr)
  12. self.loss = nn.MSELoss()
  13. def take_action(self, state):
  14. if np.random.uniform(0, 1) < 1 - self.epsilon:
  15. state = torch.tensor(state, dtype=torch.float).to(device)
  16. action = torch.argmax(self.q_net(state)).item()
  17. else:
  18. action = np.random.choice(self.action_dim)
  19. return action
  20. def update(self, transition_dict):
  21. states = transition_dict.state
  22. actions = np.expand_dims(transition_dict.action, axis=-1) # 扩充维度
  23. rewards = np.expand_dims(transition_dict.reward, axis=-1) # 扩充维度
  24. next_states = transition_dict.next_state
  25. dones = np.expand_dims(transition_dict.done, axis=-1) # 扩充维度
  26. states = torch.tensor(states, dtype=torch.float).to(device)
  27. actions = torch.tensor(actions, dtype=torch.int64).to(device)
  28. rewards = torch.tensor(rewards, dtype=torch.float).to(device)
  29. next_states = torch.tensor(next_states, dtype=torch.float).to(device)
  30. dones = torch.tensor(dones, dtype=torch.float).to(device)
  31. # update q_values
  32. # gather(1, acitons)意思是dim=1按行号索引, index=actions
  33. # actions=[[1, 2], [0, 1]] 意思是索引出[[第一行第2个元素, 第1行第3个元素],[第2行第1个元素, 第2行第2个元素]]
  34. # 相反,如果是这样
  35. # gather(0, acitons)意思是dim=0按列号索引, index=actions
  36. # actions=[[1, 2], [0, 1]] 意思是索引出[[第一列第2个元素, 第2列第3个元素],[第1列第1个元素, 第2列第2个元素]]
  37. # states.shape(64, 4) actions.shape(64, 1), 每一行是一个样本,所以这里用dim=1很合适
  38. predict_q_values = self.q_net(states).gather(1, actions)
  39. with torch.no_grad():
  40. # max(1) 即 max(dim=1)在行向找最大值,这样的话shape(64, ), 所以再加一个view(-1, 1)扩增至(64, 1)
  41. max_next_q_values = self.target_q_net(next_states).max(1)[0].view(-1, 1)
  42. q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)
  43. l = self.loss(predict_q_values, q_targets)
  44. self.optimizer.zero_grad()
  45. l.backward()
  46. self.optimizer.step()
  47. if self.count % self.target_update == 0:
  48. # copy model parameters
  49. self.target_q_net.load_state_dict(self.q_net.state_dict())
  50. self.count += 1

4.模型训练函数

  1. def run_episode(env, agent, repalymemory, batch_size):
  2. state = env.reset()
  3. reward_total = 0
  4. while True:
  5. action = agent.take_action(state)
  6. next_state, reward, done, _ = env.step(action)
  7. # print(reward)
  8. repalymemory.push(state, action, reward, next_state, done)
  9. reward_total += reward
  10. if len(repalymemory) > batch_size:
  11. state_batch, action_batch, reward_batch, next_state_batch, done_batch = repalymemory.sample(batch_size)
  12. T_data = Transition(state_batch, action_batch, reward_batch, next_state_batch, done_batch)
  13. # print(T_data)
  14. agent.update(T_data)
  15. state = next_state
  16. if done:
  17. break
  18. return reward_total
  19. def episode_evaluate(env, agent, render):
  20. reward_list = []
  21. for i in range(5):
  22. state = env.reset()
  23. reward_episode = 0
  24. while True:
  25. action = agent.take_action(state)
  26. next_state, reward, done, _ = env.step(action)
  27. reward_episode += reward
  28. state = next_state
  29. if done:
  30. break
  31. if render:
  32. env.render()
  33. reward_list.append(reward_episode)
  34. return np.mean(reward_list).item()
  35. def test(env, agent, delay_time):
  36. state = env.reset()
  37. reward_episode = 0
  38. while True:
  39. action = agent.take_action(state)
  40. next_state, reward, done, _ = env.step(action)
  41. reward_episode += reward
  42. state = next_state
  43. if done:
  44. break
  45. env.render()
  46. time. Sleep(delay_time)

5.训练模型

模型训练使用到的环境时gym提供的CartPole游戏(具体可以看这里:Cart Pole - Gym Documentation (gymlibrary.dev)),这个环境比较经典,小车运行结束的要求有三个:

(1)杆子的角度超过

(2)小车位置大于 ±2.4(小车中心到达显示屏边缘)

(3)小车移动步数超过200(v1是500)

小车每走一步奖励就会+1,所以在v0版本环境中,小车一次episode的最大奖励为200

  1. if __name__ == "__main__":
  2. # print("prepare for RL")
  3. env = gym.make("CartPole-v0")
  4. env_name = "CartPole-v0"
  5. observation_n, action_n = env.observation_space.shape[0], env.action_space.n
  6. # print(observation_n, action_n)
  7. agent = Agent(observation_n, action_n, gamma=0.98, lr=2e-3, epsilon=0.01, target_update=10)
  8. replaymemory = ReplayMemory(memory_size=10000)
  9. batch_size = 64
  10. num_episodes = 200
  11. reward_list = []
  12. # print("start to train model")
  13. # 显示10个进度条
  14. for i in range(10):
  15. with tqdm(total=int(num_episodes/10), desc="Iteration %d" % i) as pbar:
  16. for episode in range(int(num_episodes / 10)):
  17. reward_episode = run_episode(env, agent, replaymemory, batch_size)
  18. reward_list.append(reward_episode)
  19. if (episode+1) % 10 == 0:
  20. test_reward = episode_evaluate(env, agent, False)
  21. # print("Episode %d, total reward: %.3f" % (episode, test_reward))
  22. pbar.set_postfix({
  23. 'episode': '%d' % (num_episodes / 10 * i + episode + 1),
  24. 'return' : '%.3f' % (test_reward)
  25. })
  26. pbar.update(1) # 更新进度条
  27. test(env, agent, 0.5) # 最后用动画观看一下效果
  28. episodes_list = list(range(len(reward_list)))
  29. plt.plot(episodes_list, reward_list)
  30. plt.xlabel('Episodes')
  31. plt.ylabel('Returns')
  32. plt.title('Double DQN on {}'.format(env_name))
  33. plt.show()

训练结果如图所示:

参考资料:

蘑菇书EasyRL (datawhalechina.github.io)

DQN 算法 (boyuai.com)

文章知识点与官方知识档案匹配,可进一步学习相关知识
算法技能树首页概览44428 人正在系统学习中