# rl_learning

**Repository Path**: loxs/rl_learning

## Basic Information

- **Project Name**: rl_learning
- **Description**: 111111111111
- **Primary Language**: Unknown
- **License**: Not specified
- **Default Branch**: master
- **Homepage**: None
- **GVP Project**: No

## Statistics

- **Stars**: 0
- **Forks**: 0
- **Created**: 2023-10-11
- **Last Updated**: 2024-05-19

## Categories & Tags

**Categories**: Uncategorized
**Tags**: None

## README

# Environment Setup

```bash
pip install "torch>=1.13.0"
pip install "stable-baselines3[extra]"
pip install pandas
pip install matplotlib
pip install scikit-learn
pip install wandb
pip install seaborn
pip install tqdm
pip install psutil
pip install line_profiler    # time profiling
pip install memory_profiler  # memory profiling
pip install openpyxl
pip install "imageio[ffmpeg]"
pip install panda_gym
pip install gym==0.23.0
pip install beautifulsoup4   # for results visualization
```

## Installing mujoco

[mujoco installation guide](install_mujoco-py.md)

## Modifying Parts of the Library Code

```python
# stable_baselines3/common/vec_env/dummy_vec_env.py:70
# step_wait function: handle reset() returning either obs or (obs, info)
if self.buf_dones[env_idx]:
    # save final observation where user can get it, then reset
    self.buf_infos[env_idx]["terminal_observation"] = obs
    ans = self.envs[env_idx].reset()
    if isinstance(ans, tuple):
        obs, self.reset_infos[env_idx] = ans
    else:
        obs = ans
    # original line: obs, self.reset_infos[env_idx] = self.envs[env_idx].reset()
self._save_obs(env_idx, obs)

# stable_baselines3/common/vec_env/dummy_vec_env.py:82
# reset function: same compatibility handling for reset()
for env_idx in range(self.num_envs):
    maybe_options = {"options": self._options[env_idx]} if self._options[env_idx] else {}
    ans = self.envs[env_idx].reset(**maybe_options)
    if isinstance(ans, tuple):
        obs, info = ans
    else:
        obs = ans
    # original line: obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(**maybe_options)
    self._save_obs(env_idx, obs)

# gymnasium/wrappers/time_limit.py:57
# step function: handle step() returning either a 4-tuple or a 5-tuple
ans = self.env.step(action)
if len(ans) == 5:
    observation, reward, terminated, truncated, info = ans
else:
    observation, reward, terminated, info = ans
    truncated = False
self._elapsed_steps += 1

if self._elapsed_steps >= self._max_episode_steps:
    truncated = True
```

## Downloading the Code

```bash
git clone https://gitee.com/loxs/rl_learning.git
cd rl_learning
mkdir output
```

## Errors That May Occur During Installation, and Fixes

https://blog.csdn.net/dream6985/article/details/129616567

# Common Commands

- `python main.py`: run training normally, without Wandb logging
- `python main.py --use_wandb`: log to Wandb
- `python main.py --use_continue_train --config_path output/2024-01-28-06:34:28-749232`: continue training from a checkpoint
- `python main.py --config_path output/2024-01-28-06:34:28-749232`: only test the model
- `python main.py --use_weighted_info_nce`: use the weighted InfoNCE loss
- `python main.py --encoder_eps_type last`: use only the last step when training the encoder
- `kernprof -l -v main.py`: profile time costs; results are written to `tools/main.py.lprof`
- `python -m memory_profiler main.py`: profile memory costs
- `python -m tools.zip_file --input_path output/2024-01-28-06:34:28-749232`: compress the result files

# Data Flow Diagram

![image](images/rl_learning_rnn.svg)

The `hidden_state` at the start of each trajectory is **0**.
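As a minimal illustration of this zero initialization, the sketch below resets a recurrent encoder's hidden state to zeros whenever a new trajectory begins. The `TrajectoryEncoder` class, its GRU backbone, and the dimensions are assumptions for illustration, not the repository's actual encoder module.

```python
import torch
import torch.nn as nn


class TrajectoryEncoder(nn.Module):
    """Stand-in recurrent encoder; the real encoder in this repo may differ."""

    def __init__(self, obs_dim: int, hidden_dim: int = 128):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.gru = nn.GRU(obs_dim, hidden_dim, batch_first=True)

    def initial_hidden(self, batch_size: int) -> torch.Tensor:
        # At the start of every trajectory the hidden state is all zeros.
        return torch.zeros(1, batch_size, self.hidden_dim)

    def forward(self, obs_seq: torch.Tensor, hidden: torch.Tensor):
        # obs_seq: (batch, time, obs_dim); hidden: (1, batch, hidden_dim)
        out, hidden = self.gru(obs_seq, hidden)
        return out, hidden


# Usage: re-create the zero hidden state whenever a new trajectory starts.
encoder = TrajectoryEncoder(obs_dim=17)
h = encoder.initial_hidden(batch_size=1)
obs_seq = torch.randn(1, 5, 17)  # a dummy 5-step trajectory
out, h = encoder(obs_seq, h)
```

Calling `initial_hidden` again before the next trajectory keeps episodes independent of each other.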
# LFT model

![image](images/LFT_model.svg)

![image](images/LFT_model_implement.svg)

# SIFC method

![image](images/sifc_model.svg)

# TODO

- [X] Remove `env_infos` from the buffer
- [X] Use the encoder in the policy
- [X] Save the model periodically
- [X] Add wandb
- [X] Plot loss curves and t-SNE scatter plots
- [X] Add weighted InfoNCE (`--use_weighted_info_nce`)
- [X] Add resuming training from a checkpoint
- [X] Only add plots for the training environments
- [X] File compression tool
- [X] Save test results to the original folder
- [X] Fix the t-SNE display bug
- [X] Show a video for each environment during testing
- [X] Let the test-environment script set the test start date and end date
- [X] Adapt the SAC algorithm
- [X] Add environment hooks, to make adapting more environments easier later
- [X] Add our method's encoder to the policy, with a configurable weight coefficient
- [ ] Train the encoder using only the small-block / gripper information
- [ ] During training, keep the RNN `hidden_state` from the previous episode
- [X] k-nearest-neighbor method: contrastive learning on trajectories between two points (LFT method)
- [X] First select similar `context_embed`s, then pick the highest-reward sample among them as the positive sample, and pick low-reward samples from other `context_embed`s as negatives
- [ ] Make the environments image-based
- [ ] Add an option to feed only the causal part into the policy
- [ ] Display the weight curves

## Experimental Results

### domino

#### Hyperparameter Settings

```yaml
model_parameters:
  adversarial_loss_coef: 1.0
  buffer_size: 100000
  causal_dim: 6
  causal_hidden_dim: 128
  config_path: ''
  contrast_batch_size: 128
  contrast_buffer_type: all/one[OURSACFast]
  encoder_eps_type: all
  env_hook: DominoHook
  env_name: AntEnv
  gradient_steps: 32
  learning_rate: 0.0001
  method: OURSACFast/OURSACFastBase/RNNSAC
  seed: 100
  test_eps_num_per_env: 5
  time_step: 200000
  train_freq: 32
  use_weighted_info_nce: true
```
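`use_weighted_info_nce: true` and `contrast_batch_size: 128` above configure the contrastive objective (also exposed via `--use_weighted_info_nce`). For illustration only, here is a minimal sketch of an InfoNCE loss with per-negative weights; the function name, signature, and weighting scheme are assumptions and do not reproduce the repository's exact implementation.

```python
import torch
import torch.nn.functional as F


def weighted_info_nce(anchor, positive, negatives, weights, temperature: float = 0.1):
    """InfoNCE with per-negative weights (illustrative sketch, not the repo's loss).

    anchor, positive: (batch, dim) context embeddings.
    negatives:        (batch, num_neg, dim) negative embeddings.
    weights:          (batch, num_neg) non-negative weights on the negative logits.
    """
    anchor = F.normalize(anchor, dim=-1)
    positive = F.normalize(positive, dim=-1)
    negatives = F.normalize(negatives, dim=-1)

    pos_logit = (anchor * positive).sum(dim=-1, keepdim=True) / temperature    # (batch, 1)
    neg_logits = torch.einsum("bd,bnd->bn", anchor, negatives) / temperature   # (batch, num_neg)

    # Weight each negative's contribution to the partition function.
    denom = pos_logit.exp() + (weights * neg_logits.exp()).sum(dim=-1, keepdim=True)
    return (-(pos_logit - denom.log())).mean()


# Example with random embeddings: batch 128 (contrast_batch_size), dim 6 (causal_dim).
b, n, d = 128, 8, 6
loss = weighted_info_nce(
    torch.randn(b, d), torch.randn(b, d), torch.randn(b, n, d), torch.ones(b, n)
)
```

Under a positive/negative selection scheme like the one in the TODO list (high-reward samples from similar `context_embed`s as positives, low-reward samples from other `context_embed`s as negatives), the weights could for example be derived from reward differences; uniform weights recover the standard InfoNCE loss.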
#### Results

| Method | Environment | Training return | Test return |
| ---- | ---- | ---- | ---- |
| RNNSAC[0] | CrippleAntEnv | 167.84 | 158.81 |
| RNNSAC[0] | CrippleHalfCheetahEnv | 2273.49 | 1427.05 |
| RNNSAC[0] | SlimHumanoidEnv | 8607.16 | 9132.16 |
| RNNSAC[0] | HalfCheetahEnv | 1659.14 | 942.96 |
| RNNSAC[0] | AntEnv | 174.42 | 101.12 |
| RNNSAC[1] | CrippleAntEnv | - | - |
| RNNSAC[1] | CrippleHalfCheetahEnv | - | - |
| RNNSAC[1] | SlimHumanoidEnv | - | - |
| RNNSAC[1] | HalfCheetahEnv | - | - |
| RNNSAC[1] | AntEnv | - | - |
| RNNSAC[2] | CrippleAntEnv | - | - |
| RNNSAC[2] | CrippleHalfCheetahEnv | - | - |
| RNNSAC[2] | SlimHumanoidEnv | - | - |
| RNNSAC[2] | HalfCheetahEnv | - | - |
| RNNSAC[2] | AntEnv | - | - |
| OURSACFastBase[0] | CrippleAntEnv | 189.81 | 130.64 |
| OURSACFastBase[0] | CrippleHalfCheetahEnv | -222.36 | -212.74 |
| OURSACFastBase[0] | SlimHumanoidEnv | 6781.58 | 4545.07 |
| OURSACFastBase[0] | HalfCheetahEnv | 1834.43 | 1104.10 |
| OURSACFastBase[0] | AntEnv | 257.27 | 198.35 |
| OURSACFastBase[1] | CrippleAntEnv | - | - |
| OURSACFastBase[1] | CrippleHalfCheetahEnv | - | - |
| OURSACFastBase[1] | SlimHumanoidEnv | - | - |
| OURSACFastBase[1] | HalfCheetahEnv | - | - |
| OURSACFastBase[1] | AntEnv | - | - |
| OURSACFastBase[2] | CrippleAntEnv | - | - |
| OURSACFastBase[2] | CrippleHalfCheetahEnv | - | - |
| OURSACFastBase[2] | SlimHumanoidEnv | - | - |
| OURSACFastBase[2] | HalfCheetahEnv | - | - |
| OURSACFastBase[2] | AntEnv | - | - |
| OURSACFast[one-0] | CrippleAntEnv | 199.15 | 162.94 |
| OURSACFast[one-0] | CrippleHalfCheetahEnv | 1965.83 | 2069.11 |
| OURSACFast[one-0] | SlimHumanoidEnv | 7389.28 | 3401.16 |
| OURSACFast[one-0] | HalfCheetahEnv | 1836.81 | 885.97 |
| OURSACFast[one-0] | AntEnv | 35.10 | 24.70 |
| OURSACFast[one-1] | CrippleAntEnv | - | - |
| OURSACFast[one-1] | CrippleHalfCheetahEnv | - | - |
| OURSACFast[one-1] | SlimHumanoidEnv | - | - |
| OURSACFast[one-1] | HalfCheetahEnv | - | - |
| OURSACFast[one-1] | AntEnv | - | - |
| OURSACFast[one-2] | CrippleAntEnv | - | - |
| OURSACFast[one-2] | CrippleHalfCheetahEnv | - | - |
| OURSACFast[one-2] | SlimHumanoidEnv | - | - |
| OURSACFast[one-2] | HalfCheetahEnv | - | - |
| OURSACFast[one-2] | AntEnv | - | - |
| OURSACFast[all-0] | CrippleAntEnv | 207.02 | 168.13 |
| OURSACFast[all-0] | CrippleHalfCheetahEnv | 3951.40 | 2307.07 |
| OURSACFast[all-0] | SlimHumanoidEnv | 3087.06 | 4886.00 |
| OURSACFast[all-0] | HalfCheetahEnv | 1862.99 | 1083.29 |
| OURSACFast[all-0] | AntEnv | 159.60 | 164.35 |
| OURSACFast[all-1] | CrippleAntEnv | - | - |
| OURSACFast[all-1] | CrippleHalfCheetahEnv | - | - |
| OURSACFast[all-1] | SlimHumanoidEnv | - | - |
| OURSACFast[all-1] | HalfCheetahEnv | - | - |
| OURSACFast[all-1] | AntEnv | - | - |
| OURSACFast[all-2] | CrippleAntEnv | - | - |
| OURSACFast[all-2] | CrippleHalfCheetahEnv | - | - |
| OURSACFast[all-2] | SlimHumanoidEnv | - | - |
| OURSACFast[all-2] | HalfCheetahEnv | - | - |
| OURSACFast[all-2] | AntEnv | - | - |
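The training and test returns above are presumably mean episode returns averaged over the evaluation episodes (`test_eps_num_per_env: 5` in the config). A generic sketch of such an evaluation loop under the pinned `gym==0.23.0` API is shown below; the environment and the random policy are stand-ins, not the repository's wrapped environments or trained agents, and the environment requires mujoco-py to be installed.

```python
import gym
import numpy as np


def evaluate(env: gym.Env, policy, num_episodes: int = 5) -> float:
    """Average undiscounted episode return of `policy` over `num_episodes`."""
    returns = []
    for _ in range(num_episodes):
        obs = env.reset()  # gym==0.23 API: reset() returns only the observation
        done, episode_return = False, 0.0
        while not done:
            action = policy(obs)  # hypothetical policy: obs -> action
            obs, reward, done, info = env.step(action)
            episode_return += reward
        returns.append(episode_return)
    return float(np.mean(returns))


if __name__ == "__main__":
    env = gym.make("HalfCheetah-v3")  # stand-in for the repo's custom environments
    random_policy = lambda obs: env.action_space.sample()
    print(evaluate(env, random_policy, num_episodes=5))
```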