ACT Code Walkthrough

张开发
2026/4/12 14:52:44 · 15 min read


I. Generating data with record_sim_episodes.py

record_sim_episodes.py generates demonstration data in simulation. As its docstring explains, the pipeline has two stages: first roll out a scripted policy (defined in end-effector space) in ee_sim_env and obtain the joint trajectory; replace the recorded gripper joint positions with the commanded gripper positions; then replay this joint trajectory as an action sequence in sim_env and record all observations; finally save the episode and continue to the next one. The file is walked through piece by piece below.
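Since the script is driven by `main(vars(parser.parse_args()))`, it can also be invoked programmatically with the same argument names. A minimal sketch, where the dataset directory and episode count are placeholder values rather than repo defaults:

```python
# Illustrative only: call the generator with the same keys the CLI parser would produce.
from record_sim_episodes import main

main({
    'task_name': 'sim_transfer_cube_scripted',
    'dataset_dir': 'data/sim_transfer_cube_scripted',  # hypothetical output directory
    'num_episodes': 50,                                 # placeholder episode count
    'onscreen_render': False,
})
```

The equivalent shell command would be `python3 record_sim_episodes.py --task_name sim_transfer_cube_scripted --dataset_dir data/sim_transfer_cube_scripted --num_episodes 50`.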
1. Imports and global configuration

```python
import time
import os
import numpy as np
import argparse
import matplotlib.pyplot as plt
import h5py

from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN, SIM_TASK_CONFIGS
from ee_sim_env import make_ee_sim_env
from sim_env import make_sim_env, BOX_POSE
from scripted_policy import PickAndTransferPolicy, InsertionPolicy

import IPython
e = IPython.embed
```

- constants: holds the task configuration SIM_TASK_CONFIGS (per-task episode length, camera name list, etc.) and the gripper normalization function PUPPET_GRIPPER_POSITION_NORMALIZE_FN, which maps a raw gripper control signal to a normalized joint position.
- ee_sim_env: provides the end-effector-space simulation environment. Its action space is the end-effector pose (delta translation, rotation, gripper opening), so the scripted policy does not have to solve inverse kinematics itself.
- sim_env: provides the full joint-space simulation environment; actions are joint angles directly. BOX_POSE is a global variable that sets the initial pose of the box object; before the second-stage replay it must match the object configuration recorded at step 0 of the first stage.
- scripted_policy: contains the two scripted policies, PickAndTransferPolicy and InsertionPolicy, which take the current observation and output an action in end-effector space.

2. Argument parsing and entry point

```python
parser = argparse.ArgumentParser()
parser.add_argument('--task_name', required=True)
parser.add_argument('--dataset_dir', required=True)
parser.add_argument('--num_episodes', type=int)
parser.add_argument('--onscreen_render', action='store_true')
main(vars(parser.parse_args()))
```

- task_name: required, e.g. sim_transfer_cube_scripted or sim_insertion_scripted; it determines which scripted policy and which task configuration are used.
- dataset_dir: directory where the HDF5 files are saved.
- num_episodes: number of demonstration episodes to generate. The script sets no explicit default, so in practice the caller always supplies it.
- onscreen_render: whether to render the simulation on screen in real time, for debugging or observation.

3. The main function, step by step

```python
task_name = args['task_name']
dataset_dir = args['dataset_dir']
num_episodes = args['num_episodes']
onscreen_render = args['onscreen_render']
inject_noise = False        # the scripted policy injects no noise
render_cam_name = 'angle'   # camera view used for on-screen rendering, usually 'angle'

if not os.path.isdir(dataset_dir):
    os.makedirs(dataset_dir, exist_ok=True)

# read the task-specific configuration (episode length, camera name list) from SIM_TASK_CONFIGS
episode_len = SIM_TASK_CONFIGS[task_name]['episode_len']
camera_names = SIM_TASK_CONFIGS[task_name]['camera_names']

# pick the scripted policy class according to the task name; it generates the actions later
if task_name == 'sim_transfer_cube_scripted':
    policy_cls = PickAndTransferPolicy
elif task_name == 'sim_insertion_scripted':
    policy_cls = InsertionPolicy
else:
    raise NotImplementedError
```
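For orientation, an entry in SIM_TASK_CONFIGS might look roughly like the sketch below. The keys episode_len and camera_names are the ones read here, while dataset_dir and num_episodes are consumed later by the training script; the concrete values are placeholders, not the repo's actual numbers.

```python
# Hypothetical shape of a SIM_TASK_CONFIGS entry in constants.py (values are placeholders).
SIM_TASK_CONFIGS = {
    'sim_transfer_cube_scripted': {
        'dataset_dir': 'data/sim_transfer_cube_scripted',  # where episode_*.hdf5 files live
        'num_episodes': 50,        # how many demonstrations the task expects
        'episode_len': 400,        # steps per episode
        'camera_names': ['top'],   # each name becomes /observations/images/<name> in the HDF5 file
    },
}
```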
The first stage rolls out the scripted policy in the end-effector-space environment and records every TimeStep:

```python
success = []
for episode_idx in range(num_episodes):
    print(f'{episode_idx=}')
    print('Rollout out EE space scripted policy')
    # create the end-effector-space environment
    env = make_ee_sim_env(task_name)
    ts = env.reset()                    # reset and get the initial TimeStep ts
    episode = [ts]                      # stores the TimeStep of every step (observation, reward, done flag, ...)
    policy = policy_cls(inject_noise)   # instantiate the scripted policy, without noise injection
    # setup plotting
    if onscreen_render:
        ax = plt.subplot()
        plt_img = ax.imshow(ts.observation['images'][render_cam_name])
        plt.ion()
    for step in range(episode_len):
        action = policy(ts)             # the policy maps the current TimeStep ts to an EE-space action
        ts = env.step(action)
        episode.append(ts)
        if onscreen_render:
            plt_img.set_data(ts.observation['images'][render_cam_name])
            plt.pause(0.002)
    plt.close()

    # total and maximum reward over the whole episode
    episode_return = np.sum([ts.reward for ts in episode[1:]])
    episode_max_reward = np.max([ts.reward for ts in episode[1:]])
    if episode_max_reward == env.task.max_reward:
        print(f'{episode_idx=} Successful, {episode_return=}')
    else:
        print(f'{episode_idx=} Failed')

    # extract the joint positions and gripper control signals of every timestep
    joint_traj = [ts.observation['qpos'] for ts in episode]
    # replace gripper pose with gripper control
    gripper_ctrl_traj = [ts.observation['gripper_ctrl'] for ts in episode]
    for joint, ctrl in zip(joint_traj, gripper_ctrl_traj):
        left_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[0])
        right_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[2])
        joint[6] = left_ctrl        # left gripper slot of the 14-dim joint vector
        joint[6 + 7] = right_ctrl   # right gripper slot

    subtask_info = episode[0].observation['env_state'].copy()  # box pose at step 0

    # clear the first-stage resources
    del env
    del episode
    del policy
```
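PUPPET_GRIPPER_POSITION_NORMALIZE_FN is defined in constants.py; conceptually it rescales a raw gripper control value into a normalized opening. A minimal sketch of that kind of mapping, with made-up limits (the real open/close positions live in constants.py):

```python
# Sketch only: a linear rescaling of the kind PUPPET_GRIPPER_POSITION_NORMALIZE_FN performs.
# GRIPPER_OPEN / GRIPPER_CLOSE are placeholder numbers, not the repo's constants.
GRIPPER_OPEN, GRIPPER_CLOSE = 0.06, 0.01

def gripper_position_normalize(raw_ctrl):
    """Map a raw gripper control value to roughly [0, 1] (0 = closed, 1 = open)."""
    return (raw_ctrl - GRIPPER_CLOSE) / (GRIPPER_OPEN - GRIPPER_CLOSE)
```

Writing these normalized values into slots 6 and 13 of each qpos vector means the recorded action carries the commanded gripper opening rather than the lagging measured one, which is exactly what the script's docstring describes.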
The second stage replays the recorded joint trajectory in the joint-space environment and records the observations that will actually be saved:

```python
    # set up the joint-space environment sim_env
    print('Replaying joint commands')
    env = make_sim_env(task_name)
    BOX_POSE[0] = subtask_info  # make sure sim_env has the same object configuration as ee_sim_env
    ts = env.reset()            # reset and get the initial TimeStep ts
    episode_replay = [ts]
    # setup plotting
    if onscreen_render:
        ax = plt.subplot()
        plt_img = ax.imshow(ts.observation['images'][render_cam_name])
        plt.ion()
    for t in range(len(joint_traj)):  # note: this will increase episode length by 1
        action = joint_traj[t]
        ts = env.step(action)
        episode_replay.append(ts)
        if onscreen_render:
            plt_img.set_data(ts.observation['images'][render_cam_name])
            plt.pause(0.02)

    episode_return = np.sum([ts.reward for ts in episode_replay[1:]])
    episode_max_reward = np.max([ts.reward for ts in episode_replay[1:]])
    if episode_max_reward == env.task.max_reward:
        success.append(1)
        print(f'{episode_idx=} Successful, {episode_return=}')
    else:
        success.append(0)
        print(f'{episode_idx=} Failed')

    plt.close()
```

Each timestep is then flattened into a data_dict whose keys mirror the HDF5 layout:

```python
    """
    For each timestep:
    observations
    - images
        - each_cam_name     (480, 640, 3)   'uint8'
    - qpos                  (14,)           'float64'
    - qvel                  (14,)           'float64'

    action                  (14,)           'float64'
    """

    data_dict = {
        '/observations/qpos': [],
        '/observations/qvel': [],
        '/action': [],
    }
    for cam_name in camera_names:
        data_dict[f'/observations/images/{cam_name}'] = []

    # because of the replaying, there will be eps_len + 1 actions and eps_len + 2 timesteps
    # truncate here to be consistent
    joint_traj = joint_traj[:-1]
    episode_replay = episode_replay[:-1]

    # len(joint_traj) i.e. actions: max_timesteps
    # len(episode_replay) i.e. time steps: max_timesteps + 1
    max_timesteps = len(joint_traj)
    while joint_traj:
        action = joint_traj.pop(0)
        ts = episode_replay.pop(0)
        data_dict['/observations/qpos'].append(ts.observation['qpos'])
        data_dict['/observations/qvel'].append(ts.observation['qvel'])
        data_dict['/action'].append(action)
        for cam_name in camera_names:
            data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])
```

Finally the episode is written to a single HDF5 file, episode_{episode_idx}.hdf5, with the image datasets chunked per frame (the commented-out lines show optional compression settings):

```python
    # HDF5
    t0 = time.time()
    dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}')
    with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024 ** 2 * 2) as root:
        root.attrs['sim'] = True
        obs = root.create_group('observations')
        image = obs.create_group('images')
        for cam_name in camera_names:
            _ = image.create_dataset(cam_name, (max_timesteps, 480, 640, 3), dtype='uint8',
                                     chunks=(1, 480, 640, 3), )
            # compression='gzip', compression_opts=2,)
            # compression=32001, compression_opts=(0, 0, 0, 0, 9, 1, 1), shuffle=False)
        qpos = obs.create_dataset('qpos', (max_timesteps, 14))
        qvel = obs.create_dataset('qvel', (max_timesteps, 14))
        action = root.create_dataset('action', (max_timesteps, 14))

        for name, array in data_dict.items():
            root[name][...] = array
    print(f'Saving: {time.time() - t0:.1f} secs\n')

print(f'Saved to {dataset_dir}')
print(f'Success: {np.sum(success)} / {len(success)}')
```

Back at module level, the entry point registers the command-line arguments described in section 2:

```python
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
    parser.add_argument('--dataset_dir', action='store', type=str, help='dataset saving dir', required=True)
    parser.add_argument('--num_episodes', action='store', type=int, help='num_episodes', required=False)
    parser.add_argument('--onscreen_render', action='store_true')

    main(vars(parser.parse_args()))
```
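Before moving on to training, a quick sanity check (not part of the repo) is to open one of the generated files and confirm the layout described above; the path below is a placeholder:

```python
# Open one saved episode and print the dataset shapes (illustrative path).
import h5py

with h5py.File('data/sim_transfer_cube_scripted/episode_0.hdf5', 'r') as root:
    print('sim:', root.attrs['sim'])
    print('qpos:', root['/observations/qpos'].shape)    # (max_timesteps, 14)
    print('qvel:', root['/observations/qvel'].shape)    # (max_timesteps, 14)
    print('action:', root['/action'].shape)             # (max_timesteps, 14)
    for cam_name in root['/observations/images']:
        print(cam_name, root[f'/observations/images/{cam_name}'].shape)  # (max_timesteps, 480, 640, 3)
```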
II. Training the model: imitate_episodes.py

```python
import torch
import numpy as np
import os
import pickle
import argparse
import matplotlib.pyplot as plt
from copy import deepcopy
from tqdm import tqdm
from einops import rearrange

from constants import DT                    # constants: task configs, physics timestep, gripper opening, etc.
from constants import PUPPET_GRIPPER_JOINT_OPEN
from utils import load_data                 # data functions
from utils import sample_box_pose, sample_insertion_pose    # robot functions
from utils import compute_dict_mean, set_seed, detach_dict  # helper functions
from policy import ACTPolicy, CNNMLPPolicy  # defines the ACTPolicy and CNNMLPPolicy classes
from visualize_episodes import save_videos  # saves rollout videos

from sim_env import BOX_POSE                # simulated (or real) robot environment

import IPython
e = IPython.embed
```

main builds the global config from the command line. Depending on the eval flag it either trains (load the data, train the model, save the best checkpoint) or evaluates (load a trained checkpoint and roll it out in the environment):

```python
def main(args):
    set_seed(1)
    # command line parameters
    is_eval = args['eval']                      # evaluate only, no training
    ckpt_dir = args['ckpt_dir']                 # directory for saving/loading checkpoints
    policy_class = args['policy_class']         # policy type: ACT or CNNMLP
    onscreen_render = args['onscreen_render']   # render on screen during evaluation
    task_name = args['task_name']
    batch_size_train = args['batch_size']
    batch_size_val = args['batch_size']         # training and validation batch size
    num_epochs = args['num_epochs']             # number of training epochs

    # get task parameters
    is_sim = task_name[:4] == 'sim_'
    if is_sim:
        from constants import SIM_TASK_CONFIGS
        task_config = SIM_TASK_CONFIGS[task_name]
    else:
        from aloha_scripts.constants import TASK_CONFIGS
        task_config = TASK_CONFIGS[task_name]
    dataset_dir = task_config['dataset_dir']
    num_episodes = task_config['num_episodes']
    episode_len = task_config['episode_len']
    camera_names = task_config['camera_names']

    # fixed parameters
    state_dim = 14
    lr_backbone = 1e-5
    backbone = 'resnet18'
    if policy_class == 'ACT':
        enc_layers = 4
        dec_layers = 7
        nheads = 8
        policy_config = {'lr': args['lr'],
                         'num_queries': args['chunk_size'],
                         'kl_weight': args['kl_weight'],
                         'hidden_dim': args['hidden_dim'],
                         'dim_feedforward': args['dim_feedforward'],
                         'lr_backbone': lr_backbone,
                         'backbone': backbone,
                         'enc_layers': enc_layers,
                         'dec_layers': dec_layers,
                         'nheads': nheads,
                         'camera_names': camera_names,
                         }
    elif policy_class == 'CNNMLP':
        policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone': backbone,
                         'num_queries': 1, 'camera_names': camera_names, }
    else:
        raise NotImplementedError

    config = {
        'num_epochs': num_epochs,
        'ckpt_dir': ckpt_dir,
        'episode_len': episode_len,
        'state_dim': state_dim,
        'lr': args['lr'],
        'policy_class': policy_class,
        'onscreen_render': onscreen_render,
        'policy_config': policy_config,
        'task_name': task_name,
        'seed': args['seed'],
        'temporal_agg': args['temporal_agg'],
        'camera_names': camera_names,
        'real_robot': not is_sim
    }

    if is_eval:
        ckpt_names = [f'policy_best.ckpt']
        results = []
        for ckpt_name in ckpt_names:
            success_rate, avg_return = eval_bc(config, ckpt_name, save_episode=True)
            results.append([ckpt_name, success_rate, avg_return])
        for ckpt_name, success_rate, avg_return in results:
            print(f'{ckpt_name}: {success_rate=} {avg_return=}')
        print()
        exit()

    train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names,
                                                           batch_size_train, batch_size_val)

    # save dataset stats
    if not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)
    stats_path = os.path.join(ckpt_dir, f'dataset_stats.pkl')
    with open(stats_path, 'wb') as f:
        pickle.dump(stats, f)

    best_ckpt_info = train_bc(train_dataloader, val_dataloader, config)
    best_epoch, min_val_loss, best_state_dict = best_ckpt_info

    # save best checkpoint
    ckpt_path = os.path.join(ckpt_dir, f'policy_best.ckpt')
    torch.save(best_state_dict, ckpt_path)
    print(f'Best ckpt, val loss {min_val_loss:.6f} @ epoch{best_epoch}')
```

make_policy and make_optimizer instantiate the chosen policy class and its optimizer; get_image stacks the current camera images into a single tensor for the policy:

```python
def make_policy(policy_class, policy_config):
    if policy_class == 'ACT':
        policy = ACTPolicy(policy_config)
    elif policy_class == 'CNNMLP':
        policy = CNNMLPPolicy(policy_config)
    else:
        raise NotImplementedError
    return policy


def make_optimizer(policy_class, policy):
    if policy_class == 'ACT':
        optimizer = policy.configure_optimizers()
    elif policy_class == 'CNNMLP':
        optimizer = policy.configure_optimizers()
    else:
        raise NotImplementedError
    return optimizer


def get_image(ts, camera_names):
    curr_images = []
    for cam_name in camera_names:
        curr_image = rearrange(ts.observation['images'][cam_name], 'h w c -> c h w')
        curr_images.append(curr_image)
    curr_image = np.stack(curr_images, axis=0)
    curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0)
    return curr_image
```
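get_image therefore returns a float tensor of shape (1, num_cameras, 3, 480, 640) with pixel values scaled to [0, 1]. A quick check with dummy images in place of simulator renders (and without the .cuda() call, so it runs on CPU):

```python
# Not from the repo: verify the shape get_image would produce, using random images.
import numpy as np
import torch
from einops import rearrange

camera_names = ['top', 'angle']    # example camera list
fake_images = {cam: np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) for cam in camera_names}

curr_images = [rearrange(fake_images[cam], 'h w c -> c h w') for cam in camera_names]
curr_image = torch.from_numpy(np.stack(curr_images, axis=0) / 255.0).float().unsqueeze(0)
print(curr_image.shape)  # torch.Size([1, 2, 3, 480, 640])
```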
eval_bc loads the checkpoint and the dataset statistics, builds the environment, and rolls the policy out. qpos is normalized with the training-set statistics before being fed to the policy, and predicted actions are de-normalized before being sent to the robot:

```python
def eval_bc(config, ckpt_name, save_episode=True):
    set_seed(1000)
    ckpt_dir = config['ckpt_dir']
    state_dim = config['state_dim']
    real_robot = config['real_robot']
    policy_class = config['policy_class']
    onscreen_render = config['onscreen_render']
    policy_config = config['policy_config']
    camera_names = config['camera_names']
    max_timesteps = config['episode_len']
    task_name = config['task_name']
    temporal_agg = config['temporal_agg']
    onscreen_cam = 'angle'

    # load policy and stats
    ckpt_path = os.path.join(ckpt_dir, ckpt_name)
    policy = make_policy(policy_class, policy_config)
    loading_status = policy.load_state_dict(torch.load(ckpt_path))
    print(loading_status)
    policy.cuda()
    policy.eval()
    print(f'Loaded: {ckpt_path}')
    stats_path = os.path.join(ckpt_dir, f'dataset_stats.pkl')
    with open(stats_path, 'rb') as f:
        stats = pickle.load(f)

    pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
    post_process = lambda a: a * stats['action_std'] + stats['action_mean']

    # load environment
    if real_robot:
        from aloha_scripts.robot_utils import move_grippers  # requires aloha
        from aloha_scripts.real_env import make_real_env     # requires aloha
        env = make_real_env(init_node=True)
        env_max_reward = 0
    else:
        from sim_env import make_sim_env
        env = make_sim_env(task_name)
        env_max_reward = env.task.max_reward

    query_frequency = policy_config['num_queries']
    if temporal_agg:
        query_frequency = 1
        num_queries = policy_config['num_queries']

    max_timesteps = int(max_timesteps * 1)  # may increase for real-world tasks

    num_rollouts = 50
    episode_returns = []
    highest_rewards = []
    for rollout_id in range(num_rollouts):
        rollout_id += 0
        ### set task
        if 'sim_transfer_cube' in task_name:
            BOX_POSE[0] = sample_box_pose()  # used in sim reset
        elif 'sim_insertion' in task_name:
            BOX_POSE[0] = np.concatenate(sample_insertion_pose())  # used in sim reset

        ts = env.reset()

        ### onscreen render
        if onscreen_render:
            ax = plt.subplot()
            plt_img = ax.imshow(env._physics.render(height=480, width=640, camera_id=onscreen_cam))
            plt.ion()

        ### evaluation loop
        if temporal_agg:
            all_time_actions = torch.zeros([max_timesteps, max_timesteps + num_queries, state_dim]).cuda()

        qpos_history = torch.zeros((1, max_timesteps, state_dim)).cuda()
        image_list = []  # for visualization
        qpos_list = []
        target_qpos_list = []
        rewards = []
        with torch.inference_mode():
            for t in range(max_timesteps):
                ### update onscreen render and wait for DT
                if onscreen_render:
                    image = env._physics.render(height=480, width=640, camera_id=onscreen_cam)
                    plt_img.set_data(image)
                    plt.pause(DT)

                ### process previous timestep to get qpos and image_list
                obs = ts.observation
                if 'images' in obs:
                    image_list.append(obs['images'])
                else:
                    image_list.append({'main': obs['image']})
                qpos_numpy = np.array(obs['qpos'])
                qpos = pre_process(qpos_numpy)
                qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
                qpos_history[:, t] = qpos
                curr_image = get_image(ts, camera_names)

                ### query policy
                if config['policy_class'] == "ACT":
                    if t % query_frequency == 0:
                        all_actions = policy(qpos, curr_image)
                    if temporal_agg:
                        all_time_actions[[t], t:t + num_queries] = all_actions
                        actions_for_curr_step = all_time_actions[:, t]
                        actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
                        actions_for_curr_step = actions_for_curr_step[actions_populated]
                        k = 0.01
                        exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
                        exp_weights = exp_weights / exp_weights.sum()
                        exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
                        raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
                    else:
                        raw_action = all_actions[:, t % query_frequency]
                elif config['policy_class'] == "CNNMLP":
                    raw_action = policy(qpos, curr_image)
                else:
                    raise NotImplementedError

                ### post-process actions
                raw_action = raw_action.squeeze(0).cpu().numpy()
                action = post_process(raw_action)
                target_qpos = action

                ### step the environment
                ts = env.step(target_qpos)

                ### for visualization
                qpos_list.append(qpos_numpy)
                target_qpos_list.append(target_qpos)
                rewards.append(ts.reward)

            plt.close()
```
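Without temporal aggregation, the ACT policy is queried once every query_frequency (= chunk_size) steps and the predicted chunk is executed open-loop. With --temporal_agg, the policy is queried every step, each prediction is written into a row of all_time_actions, and the action executed at step t is an exponentially weighted average of every prediction that covers t. The weighting logic can be isolated into a small sketch (num_queries and the dummy predictions are illustrative, not repo code):

```python
# Standalone sketch of the exponential weighting used for temporal aggregation.
import numpy as np

def aggregate(predictions_for_t, k=0.01):
    """predictions_for_t: (n, state_dim) array of all chunk predictions that cover step t,
    ordered from the oldest query to the newest; returns one aggregated action."""
    weights = np.exp(-k * np.arange(len(predictions_for_t)))  # older predictions get larger weight
    weights = weights / weights.sum()
    return (predictions_for_t * weights[:, None]).sum(axis=0)

dummy = np.random.randn(5, 14)   # pretend 5 earlier queries each predicted an action for step t
print(aggregate(dummy).shape)    # (14,)
```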
The remainder of eval_bc aggregates the rewards of each rollout, prints a success summary, optionally saves a video, and writes the results to a text file next to the checkpoint:

```python
        if real_robot:
            move_grippers([env.puppet_bot_left, env.puppet_bot_right],
                          [PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5)  # open
            pass

        rewards = np.array(rewards)
        episode_return = np.sum(rewards[rewards != None])
        episode_returns.append(episode_return)
        episode_highest_reward = np.max(rewards)
        highest_rewards.append(episode_highest_reward)
        print(f'Rollout {rollout_id}\n{episode_return=}, {episode_highest_reward=}, {env_max_reward=}, Success: {episode_highest_reward == env_max_reward}')

        if save_episode:
            save_videos(image_list, DT, video_path=os.path.join(ckpt_dir, f'video{rollout_id}.mp4'))

    success_rate = np.mean(np.array(highest_rewards) == env_max_reward)
    avg_return = np.mean(episode_returns)
    summary_str = f'\nSuccess rate: {success_rate}\nAverage return: {avg_return}\n\n'
    for r in range(env_max_reward + 1):
        more_or_equal_r = (np.array(highest_rewards) >= r).sum()
        more_or_equal_r_rate = more_or_equal_r / num_rollouts
        summary_str += f'Reward >= {r}: {more_or_equal_r}/{num_rollouts} = {more_or_equal_r_rate*100}%\n'

    print(summary_str)

    # save success rate to txt
    result_file_name = 'result_' + ckpt_name.split('.')[0] + '.txt'
    with open(os.path.join(ckpt_dir, result_file_name), 'w') as f:
        f.write(summary_str)
        f.write(repr(episode_returns))
        f.write('\n\n')
        f.write(repr(highest_rewards))

    return success_rate, avg_return
```

forward_pass moves one batch to the GPU and calls the policy; train_bc runs the standard train/validate loop, tracking the best validation loss and periodically saving checkpoints:

```python
def forward_pass(data, policy):
    image_data, qpos_data, action_data, is_pad = data
    image_data, qpos_data, action_data, is_pad = image_data.cuda(), qpos_data.cuda(), action_data.cuda(), is_pad.cuda()
    return policy(qpos_data, image_data, action_data, is_pad)  # TODO remove None


def train_bc(train_dataloader, val_dataloader, config):
    num_epochs = config['num_epochs']
    ckpt_dir = config['ckpt_dir']
    seed = config['seed']
    policy_class = config['policy_class']
    policy_config = config['policy_config']

    set_seed(seed)

    policy = make_policy(policy_class, policy_config)
    policy.cuda()
    optimizer = make_optimizer(policy_class, policy)

    train_history = []
    validation_history = []
    min_val_loss = np.inf
    best_ckpt_info = None
    for epoch in tqdm(range(num_epochs)):
        print(f'\nEpoch {epoch}')
        # validation
        with torch.inference_mode():
            policy.eval()
            epoch_dicts = []
            for batch_idx, data in enumerate(val_dataloader):
                forward_dict = forward_pass(data, policy)
                epoch_dicts.append(forward_dict)
            epoch_summary = compute_dict_mean(epoch_dicts)
            validation_history.append(epoch_summary)

            epoch_val_loss = epoch_summary['loss']
            if epoch_val_loss < min_val_loss:
                min_val_loss = epoch_val_loss
                best_ckpt_info = (epoch, min_val_loss, deepcopy(policy.state_dict()))
        print(f'Val loss:   {epoch_val_loss:.5f}')
        summary_string = ''
        for k, v in epoch_summary.items():
            summary_string += f'{k}: {v.item():.3f} '
        print(summary_string)

        # training
        policy.train()
        optimizer.zero_grad()
        for batch_idx, data in enumerate(train_dataloader):
            forward_dict = forward_pass(data, policy)
            # backward
            loss = forward_dict['loss']
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            train_history.append(detach_dict(forward_dict))
        epoch_summary = compute_dict_mean(train_history[(batch_idx+1)*epoch:(batch_idx+1)*(epoch+1)])
        epoch_train_loss = epoch_summary['loss']
        print(f'Train loss: {epoch_train_loss:.5f}')
        summary_string = ''
        for k, v in epoch_summary.items():
            summary_string += f'{k}: {v.item():.3f} '
        print(summary_string)

        if epoch % 100 == 0:
            ckpt_path = os.path.join(ckpt_dir, f'policy_epoch_{epoch}_seed_{seed}.ckpt')
            torch.save(policy.state_dict(), ckpt_path)
            plot_history(train_history, validation_history, epoch, ckpt_dir, seed)

    ckpt_path = os.path.join(ckpt_dir, f'policy_last.ckpt')
    torch.save(policy.state_dict(), ckpt_path)

    best_epoch, min_val_loss, best_state_dict = best_ckpt_info
    ckpt_path = os.path.join(ckpt_dir, f'policy_epoch_{best_epoch}_seed_{seed}.ckpt')
    torch.save(best_state_dict, ckpt_path)
    print(f'Training finished:\nSeed {seed}, val loss {min_val_loss:.6f} at epoch {best_epoch}')

    # save training curves
    plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed)

    return best_ckpt_info
```
plot_history saves one train/validation curve per loss term, and the entry point wires up the command-line arguments:

```python
def plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed):
    # save training curves
    for key in train_history[0]:
        plot_path = os.path.join(ckpt_dir, f'train_val_{key}_seed_{seed}.png')
        plt.figure()
        train_values = [summary[key].item() for summary in train_history]
        val_values = [summary[key].item() for summary in validation_history]
        plt.plot(np.linspace(0, num_epochs - 1, len(train_history)), train_values, label='train')
        plt.plot(np.linspace(0, num_epochs - 1, len(validation_history)), val_values, label='validation')
        # plt.ylim([-0.1, 1])
        plt.tight_layout()
        plt.legend()
        plt.title(key)
        plt.savefig(plot_path)
    print(f'Saved plots to {ckpt_dir}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--onscreen_render', action='store_true')
    parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True)
    parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', required=True)
    parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
    parser.add_argument('--batch_size', action='store', type=int, help='batch_size', required=True)
    parser.add_argument('--seed', action='store', type=int, help='seed', required=True)
    parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', required=True)
    parser.add_argument('--lr', action='store', type=float, help='lr', required=True)

    # for ACT
    parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
    parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
    parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False)
    parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False)
    parser.add_argument('--temporal_agg', action='store_true')

    main(vars(parser.parse_args()))
```
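Putting it together, a training run might be launched with something like `python3 imitate_episodes.py --task_name sim_transfer_cube_scripted --ckpt_dir <ckpt dir> --policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 512 --batch_size 8 --dim_feedforward 3200 --num_epochs 2000 --lr 1e-5 --seed 0`; the hyperparameter values here are illustrative, not prescribed by the script. Evaluation reuses the same command with --eval added (and optionally --temporal_agg); in that mode main loads policy_best.ckpt from ckpt_dir and calls eval_bc.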
