import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib import animation
import imageio
class Sender:
    """Agent that turns an observed input into a message.

    Keeps one weight per (input, message) pair in ``message_weights``
    (shape ``(n_inputs, n_messages)``). Messages are sampled from a softmax
    over the row belonging to the observed input, and the weight of the most
    recently used (input, message) pair is adjusted by the received reward.
    """

    def __init__(self, n_inputs: int, n_messages: int, eps: float = 1e-6):
        """Create a sender whose weights all start at ``eps``.

        Args:
            n_inputs: Number of distinct inputs the sender can observe.
            n_messages: Number of distinct messages the sender can emit.
            eps: Initial value of every weight.
        """
        self.n_messages = n_messages
        # dtype=float keeps the table floating point even if eps is an int.
        self.message_weights = np.full((n_inputs, n_messages), eps, dtype=float)
        # (input, message) pair of the latest decision; read by
        # learn_from_feedback to know which weight to update.
        self.last_situation = (0, 0)

    @staticmethod
    def _softmax(logits):
        """Return the softmax of a 1-D array of logits.

        The maximum is subtracted first so that ``np.exp`` cannot overflow
        for large weights; the result is mathematically unchanged.
        """
        exp_shifted = np.exp(logits - np.max(logits))
        return exp_shifted / exp_shifted.sum()

    def send_message(self, input: int) -> int:
        """Sample a message for ``input`` and remember the decision.

        Args:
            input: Row index into ``message_weights``.

        Returns:
            Index of the sampled message, in ``range(n_messages)``.
        """
        probs = self._softmax(self.message_weights[input, :])
        message = int(np.random.choice(self.n_messages, p=probs))
        self.last_situation = (input, message)
        return message

    def learn_from_feedback(self, reward: int) -> None:
        """Add ``reward`` to the weight of the last (input, message) pair."""
        self.message_weights[self.last_situation] += reward
class Receiver:
    """Agent that turns a received message into an action.

    Keeps one weight per (message, action) pair in ``action_weights``
    (shape ``(n_messages, n_actions)``). Actions are sampled from a softmax
    over the row belonging to the received message, and the weight of the
    most recently used (message, action) pair is adjusted by the reward.
    """

    def __init__(self, n_messages: int, n_actions: int, eps: float = 1e-6):
        """Create a receiver whose weights all start at ``eps``.

        Args:
            n_messages: Number of distinct messages the receiver can get.
            n_actions: Number of distinct actions the receiver can take.
            eps: Initial value of every weight.
        """
        self.n_actions = n_actions
        # np.full allocates and initialises in one step instead of filling
        # an uninitialised np.ndarray; dtype=float matches the old default.
        self.action_weights = np.full((n_messages, n_actions), eps, dtype=float)
        # (message, action) pair of the latest decision; read by
        # learn_from_feedback to know which weight to update.
        self.last_situation = (0, 0)

    @staticmethod
    def _softmax(logits):
        """Return the softmax of a 1-D array of logits.

        The maximum is subtracted first so that ``np.exp`` cannot overflow
        for large weights; the result is mathematically unchanged.
        """
        exp_shifted = np.exp(logits - np.max(logits))
        return exp_shifted / exp_shifted.sum()

    def act(self, message: int) -> int:
        """Sample an action for ``message`` and remember the decision.

        Args:
            message: Row index into ``action_weights``.

        Returns:
            Index of the sampled action, in ``range(n_actions)``.
        """
        probs = self._softmax(self.action_weights[message, :])
        action = int(np.random.choice(self.n_actions, p=probs))
        self.last_situation = (message, action)
        return action

    def learn_from_feedback(self, reward: int) -> None:
        """Add ``reward`` to the weight of the last (message, action) pair."""
        self.action_weights[self.last_situation] += reward
class World:
    """Environment that draws a random state and scores actions against it."""

    def __init__(self, n_states: int, seed: int = 1701):
        """Set up a world with ``n_states`` states and a seeded generator.

        Args:
            n_states: Number of distinct states; states are ``0..n_states-1``.
            seed: Seed for the world's private ``RandomState``.
        """
        self.rng = np.random.RandomState(seed)
        self.n_states = n_states
        self.state = 0

    def emit_state(self) -> int:
        """Draw a new state uniformly at random, store it and return it."""
        drawn = self.rng.randint(self.n_states)
        self.state = drawn
        return drawn

    def evaluate_action(self, action: int) -> int:
        """Return 1 when ``action`` equals the current state, otherwise -1."""
        if action != self.state:
            return -1
        return 1
def _policy_matrix(weights):
    """Return the row-wise softmax of ``weights``, transposed for plotting.

    Both agents sample from a softmax over one *row* of their weight table,
    so the probabilities must be normalised along axis 1. The result is
    transposed so that each heatmap column is the distribution for one
    conditioning index (message / world state), matching the axis labels.
    """
    # Subtracting the row maximum keeps np.exp from overflowing.
    shifted = weights - weights.max(axis=1, keepdims=True)
    exp_weights = np.exp(shifted)
    return (exp_weights / exp_weights.sum(axis=1, keepdims=True)).T


def _save_policy_heatmap(weights, xlabel, ylabel, title, path):
    """Draw the policy derived from ``weights`` and save it to ``path``."""
    sns.heatmap(
        _policy_matrix(weights),
        square=True, cbar=False, annot=True, fmt='.1f'
    )
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    # tight_layout only has an effect once the axes exist, so it is applied
    # after drawing and right before the figure is written out.
    plt.tight_layout(pad=0)
    plt.savefig(path)
    # Clear the figure so the next heatmap starts from a blank canvas.
    plt.clf()


# Training run: 10 world states, 10 messages, 10 actions.
sender, receiver = Sender(10, 10), Receiver(10, 10)
world = World(10)
past_rewards = 0
matrices = []  # not used below; kept as a module-level name
for epoch in range(3000):
    # One rollout: state -> message -> action -> reward, then both agents
    # update the weight of the decision they just made.
    world_state = world.emit_state()
    message = sender.send_message(world_state)
    action = receiver.act(message)
    reward = world.evaluate_action(action)
    receiver.learn_from_feedback(reward)
    sender.learn_from_feedback(reward)
    past_rewards += reward
    if epoch % 25 == 0:
        # Snapshot both policies; make_gif later stitches these frames.
        _save_policy_heatmap(
            receiver.action_weights,
            'messages', 'actions',
            f"Receiver's weights, rollout {epoch}",
            f"receiver_{epoch}.png",
        )
        _save_policy_heatmap(
            sender.message_weights,
            'world states', 'messages',
            f"Sender's weights, rollout {epoch}",
            f"sender_{epoch}.png",
        )
    if epoch % 100 == 0:
        # NOTE: at epoch 0 the accumulator holds a single rollout, so the
        # first reported value is that one reward divided by 100.
        print(f'Epoch {epoch}, last 100 epochs reward: {past_rewards/100}')
        print(world_state, message, action, reward)
        past_rewards = 0
# Highest-weight choice per row, i.e. each agent's greedy mapping.
print("Observation to message mapping:")
print(sender.message_weights.argmax(1))
print("Message to action mapping:")
print(receiver.action_weights.argmax(1))
def make_gif(filename_base, n_epochs=3000, step=25):
    """Combine saved PNG frames into ``<filename_base>.gif``.

    Reads ``<filename_base>_<epoch>.png`` for every ``epoch`` in
    ``range(0, n_epochs, step)`` and writes them, in that order, as one
    animated GIF.

    Args:
        filename_base: Common prefix of the frame files and of the output.
        n_epochs: Exclusive upper bound of the epoch numbers to include.
        step: Spacing between consecutive frame epochs.
    """
    frame_paths = [
        f'{filename_base}_{epoch}.png' for epoch in range(0, n_epochs, step)
    ]
    frames = [imageio.imread(path) for path in frame_paths]
    imageio.mimsave(f'{filename_base}.gif', frames)
# Build one animation per agent, sender first, from the saved frames.
for agent_name in ('sender', 'receiver'):
    make_gif(agent_name)