@sven1977 sorry, missed a sentence. Do you know if, in the example class, episode.last_observation_for, when called at the end of the episode, correctly returns the observation from each individual step rather than always returning the reset observation?
The problem arises even in on_episode_step, despite the observations being updated in the step function of my actual MultiAgentEnv subclass. Below is a quick sanity check I ran on the raw env (which confirms it emits fresh observations), followed by my current callback class; I also left a small commented diagnostic at the end of on_episode_step.
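(A minimal sketch: MyMultiAgentEnv and env_config stand in for my actual env class and its config, and I'm assuming a shared action_space attribute; adjust to your env.)

env = MyMultiAgentEnv(env_config)
obs = env.reset()
for t in range(5):
    # Sample a random action for every agent that received an observation.
    actions = {agent_id: env.action_space.sample() for agent_id in obs}
    obs, rewards, dones, infos = env.step(actions)
    print("step", t, "obs for agent 0:", obs[str(0)])
    if dones["__all__"]:
        break

This prints a different observation each step, which is why I suspect the episode/callback side rather than the env itself.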
from typing import Dict

import numpy as np

from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch


class MyCallBack(DefaultCallbacks):
    """
    Collects per-step consumption/savings/asset series and logs them as
    episode-level metrics and histograms (and eventually the policy
    function), so they can be graphed over training.
    """

    def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv,
                         policies: Dict[str, Policy],
                         episode: MultiAgentEpisode, env_index: int,
                         **kwargs):
        # Make sure this episode has just been started (only the initial
        # obs has been logged so far).
        assert episode.length == 0, \
            "ERROR: `on_episode_start()` callback should be called right " \
            "after env reset!"
        print("episode {} (env-idx={}) started.".format(
            episode.episode_id, env_index))
        # One per-step series in user_data and one histogram slot in
        # hist_data for each tracked variable.
        for key in ["consumption", "savings", "assets", "net_savings",
                    "all_assets"]:
            episode.user_data[key] = []
            episode.hist_data[key] = []
    def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv,
                        episode: MultiAgentEpisode, env_index: int,
                        **kwargs):
        # Make sure this episode is ongoing.
        assert episode.length > 0, \
            "ERROR: `on_episode_step()` callback should not be called right " \
            "after env reset!"
        # Observation layout in my env: index 0 = assets, 1 = price,
        # 2 = income, 5 = interest, 6 = wage.
        obs = episode.last_observation_for(str(0))
        consumption = episode.last_action_for(str(0))[0]
        assets = obs[0]
        price = obs[1]
        income = obs[2]
        interest = obs[5]
        wage = obs[6]
        net_savings: float = interest * assets + income - consumption
        savings: float = price * assets + income - consumption
        # AGENT_NUM (the number of agents) is defined elsewhere in my script.
        for i in range(AGENT_NUM):
            episode.user_data["all_assets"].append(
                episode.last_observation_for(str(i))[0])
        episode.user_data["consumption"].append(consumption)
        episode.user_data["assets"].append(assets)
        episode.user_data["savings"].append(savings)
        episode.user_data["net_savings"].append(net_savings)
    def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv,
                       policies: Dict[str, Policy],
                       episode: MultiAgentEpisode, env_index: int, **kwargs):
        consumption_mean: float = np.mean(episode.user_data["consumption"])
        net_savings_mean: float = np.mean(episode.user_data["net_savings"])
        savings_mean: float = np.mean(episode.user_data["savings"])
        assets_mean: float = np.mean(episode.user_data["all_assets"])
        savings_var: float = np.var(episode.user_data["savings"])
        net_savings_var: float = np.var(episode.user_data["net_savings"])
        consumption_var: float = np.var(episode.user_data["consumption"])
        print("episode {} (env-idx={}) ended with length {} and consumption "
              "mean and variance: {}, {} and savings mean and variance: {}, {}"
              .format(episode.episode_id, env_index, episode.length,
                      consumption_mean, consumption_var, savings_mean,
                      savings_var))
        # Graphs of mean and variance over training.
        episode.custom_metrics["consumption_mean"] = consumption_mean
        episode.custom_metrics["consumption_var"] = consumption_var
        episode.custom_metrics["net_savings_mean"] = net_savings_mean
        episode.custom_metrics["net_savings_var"] = net_savings_var
        episode.custom_metrics["savings_mean"] = savings_mean
        episode.custom_metrics["savings_var"] = savings_var
        episode.custom_metrics["assets_mean"] = assets_mean
        # Full per-step series for histograms.
        episode.hist_data["all_assets"] = episode.user_data["all_assets"]
        episode.hist_data["net_savings"] = episode.user_data["net_savings"]
        episode.hist_data["assets"] = episode.user_data["assets"]
        episode.hist_data["savings"] = episode.user_data["savings"]
        episode.hist_data["consumption"] = episode.user_data["consumption"]
        # Graphs of histograms over time.
        episode.custom_metrics["consumption_hist"] = episode.hist_data["consumption"]
        episode.custom_metrics["assets_hist"] = episode.hist_data["assets"]
        episode.custom_metrics["savings_hist"] = episode.hist_data["savings"]
        episode.custom_metrics["net_savings_hist"] = episode.hist_data["net_savings"]
    def on_sample_end(self, *, worker: RolloutWorker, samples: SampleBatch,
                      **kwargs):
        print("returned sample batch of size {}".format(samples.count))

    def on_train_result(self, *, trainer, result: dict, **kwargs):
        print("trainer.train() result: {} -> {} episodes".format(
            trainer, result["episodes_this_iter"]))
        # You can mutate the result dict to add new fields to return.
        result["callback_ok"] = True
    def on_learn_on_batch(self, *, policy: Policy, train_batch: SampleBatch,
                          result: dict, **kwargs) -> None:
        # Right before learning, I eventually want to display the current
        # policy over a grid of states, along these lines (still a sketch,
        # hence commented out; note that model.from_batch() returns a
        # (logits, state) tuple):
        # obs_vec = np.linspace(
        #     [policy.observation_space.low[0], 1, 1, 1, 5, 1, 1],
        #     [policy.observation_space.high[0], 1, 1, 1, 5, 1, 1], 10000)
        # prob_vec = []
        # for i in range(1000):
        #     logits, _ = policy.model.from_batch({"obs": obs_vec[i]})
        #     dist = policy.dist_class(logits, policy.model)
        #     # For a Categorical dist this is a softmax over the logits.
        #     prob_vec.append(tf.nn.softmax(dist.inputs))
        result["sum_actions_in_train_batch"] = np.sum(train_batch["actions"])
        print("policy.learn_on_batch() result: {} -> sum actions: {}".format(
            policy, result["sum_actions_in_train_batch"]))
    def on_postprocess_trajectory(
            self, *, worker: RolloutWorker, episode: MultiAgentEpisode,
            agent_id: str, policy_id: str, policies: Dict[str, Policy],
            postprocessed_batch: SampleBatch,
            original_batches: Dict[str, SampleBatch], **kwargs):
        print("postprocessed {} steps".format(postprocessed_batch.count))
        if "num_batches" not in episode.custom_metrics:
            episode.custom_metrics["num_batches"] = 0
        episode.custom_metrics["num_batches"] += 1