In [5]:
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
import imageio
In [6]:
class Sender:

    def __init__(self, n_inputs: int, n_messages: int, eps: float = 1e-6):
        self.n_messages = n_messages
        # Every weight starts at the same small eps, so the initial policy
        # is uniform over messages.
        self.message_weights = np.full((n_inputs, n_messages), eps)
        self.last_situation = (0, 0)

    def send_message(self, input: int) -> int:
        # Softmax over this input's row of the weight matrix.
        exp_weights = np.exp(self.message_weights[input, :])
        probs = exp_weights / exp_weights.sum()
        message = np.random.choice(self.n_messages, p=probs)
        self.last_situation = (input, message)
        return message

    def learn_from_feedback(self, reward: int) -> None:
        # Shift the weight of the (input, message) pair just played
        # up or down by the reward (+1 / -1).
        self.message_weights[self.last_situation] += reward
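Both agents exponentiate raw weights, and with rewards of ±1 the weight of a winning pair grows without bound over long runs, which can eventually overflow np.exp. A minimal sketch of a numerically stable softmax (the helper name is mine, not part of the original agents); subtracting the row maximum cancels in the normalisation and leaves the distribution unchanged:

def stable_softmax(weights: np.ndarray) -> np.ndarray:
    # exp(w - max(w)) never overflows, and the constant shift
    # cancels when normalising.
    shifted = np.exp(weights - weights.max())
    return shifted / shifted.sum()

Dropping it into send_message would read probs = stable_softmax(self.message_weights[input, :]).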
In [7]:
class Receiver:

    def __init__(self, n_messages: int, n_actions: int, eps: float = 1e-6):
        self.n_actions = n_actions
        # np.full initialises every entry to eps; constructing np.ndarray
        # directly would allocate uninitialized memory first.
        self.action_weights = np.full((n_messages, n_actions), eps)
        self.last_situation = (0, 0)

    def act(self, message: int) -> int:
        # Softmax over the actions available for this message.
        exp_weights = np.exp(self.action_weights[message, :])
        probs = exp_weights / exp_weights.sum()
        action = np.random.choice(self.n_actions, p=probs)
        self.last_situation = (message, action)
        return action

    def learn_from_feedback(self, reward: int) -> None:
        # Shift the weight of the (message, action) pair just used by the reward.
        self.action_weights[self.last_situation] += reward
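Since every weight starts at the same eps, both agents begin with uniform policies. A quick check (illustrative, not a cell from the original run):

r = Receiver(n_messages=3, n_actions=4)
exp_w = np.exp(r.action_weights[0, :])
print(exp_w / exp_w.sum())  # [0.25 0.25 0.25 0.25]: uniform before any feedback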
In [10]:
class World:
    def __init__(self, n_states: int, seed: int = 1701):
        self.n_states = n_states
        self.state = 0
        # Seeded RNG so runs are reproducible.
        self.rng = np.random.RandomState(seed)

    def emit_state(self) -> int:
        # Draw a new state uniformly at random and remember it for evaluation.
        self.state = self.rng.randint(self.n_states)
        return self.state

    def evaluate_action(self, action: int) -> int:
        # +1 if the action matches the current state, -1 otherwise.
        return 1 if action == self.state else -1
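The world draws a fresh state each round and pays +1 only for the matching action. A small usage sketch (illustrative values, not from the original run):

w = World(n_states=5)
s = w.emit_state()
print(w.evaluate_action(s))            # 1: guessing the current state is rewarded
print(w.evaluate_action((s + 1) % 5))  # -1: any other action is penalised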
In [11]:
sender, receiver = Sender(10, 10), Receiver(10, 10)
world = World(10)
past_rewards = 0
matrices = []
for epoch in range(3000):
    world_state = world.emit_state()
    message = sender.send_message(world_state)
    action = receiver.act(message)
    reward = world.evaluate_action(action)
    receiver.learn_from_feedback(reward)
    sender.learn_from_feedback(reward)
    past_rewards += reward
    if epoch % 25 == 0:
        # Normalise each row into the agent's softmax policy before plotting.
        # Rows of action_weights are messages, columns are actions.
        probs = np.exp(receiver.action_weights)
        sns.heatmap(
            probs / probs.sum(axis=1, keepdims=True),
            square=True, cbar=False, annot=True, fmt='.1f'
        )
        plt.xlabel('actions')
        plt.ylabel('messages')
        plt.title(f'Receiver\'s weights, rollout {epoch}')
        plt.tight_layout(pad=0)
        plt.savefig(f"receiver_{epoch}.png")
        plt.clf()

        # Rows of message_weights are world states, columns are messages.
        probs = np.exp(sender.message_weights)
        sns.heatmap(
            probs / probs.sum(axis=1, keepdims=True),
            square=True, cbar=False, annot=True, fmt='.1f'
        )
        plt.xlabel('messages')
        plt.ylabel('world states')
        plt.title(f'Sender\'s weights, rollout {epoch}')
        plt.tight_layout(pad=0)
        plt.savefig(f"sender_{epoch}.png")
        plt.clf()
           
    if epoch % 100 == 0:
        print(f'Epoch {epoch}, last 100 epochs reward: {past_rewards/100}')
        print(world_state, message, action, reward)
        past_rewards = 0

print("Observation to message mapping:")
print(sender.message_weights.argmax(1))
print("Message to action mapping:")
print(receiver.action_weights.argmax(1))
Epoch 0, last 100 epochs reward: -0.01
4 4 0 -1
Epoch 100, last 100 epochs reward: -0.74
1 7 2 -1
Epoch 200, last 100 epochs reward: -0.56
9 8 9 1
Epoch 300, last 100 epochs reward: -0.7
9 8 9 1
Epoch 400, last 100 epochs reward: -0.7
6 8 9 -1
Epoch 500, last 100 epochs reward: -0.5
4 7 6 -1
Epoch 600, last 100 epochs reward: -0.44
8 4 0 -1
Epoch 700, last 100 epochs reward: -0.42
4 2 6 -1
Epoch 800, last 100 epochs reward: -0.1
7 9 2 -1
Epoch 900, last 100 epochs reward: -0.22
1 5 6 -1
Epoch 1000, last 100 epochs reward: -0.02
4 6 4 1
Epoch 1100, last 100 epochs reward: 0.34
6 5 6 1
Epoch 1200, last 100 epochs reward: 0.28
9 2 6 -1
Epoch 1300, last 100 epochs reward: 0.54
7 6 4 -1
Epoch 1400, last 100 epochs reward: 0.54
4 6 4 1
Epoch 1500, last 100 epochs reward: 0.62
7 8 6 -1
Epoch 1600, last 100 epochs reward: 0.7
6 5 6 1
Epoch 1700, last 100 epochs reward: 0.62
2 9 2 1
Epoch 1800, last 100 epochs reward: 0.86
4 6 4 1
Epoch 1900, last 100 epochs reward: 0.82
9 2 9 1
Epoch 2000, last 100 epochs reward: 0.78
2 9 2 1
Epoch 2100, last 100 epochs reward: 0.94
0 0 0 1
Epoch 2200, last 100 epochs reward: 1.0
3 1 3 1
Epoch 2300, last 100 epochs reward: 1.0
8 4 8 1
Epoch 2400, last 100 epochs reward: 1.0
1 3 1 1
Epoch 2500, last 100 epochs reward: 1.0
8 4 8 1
Epoch 2600, last 100 epochs reward: 1.0
2 9 2 1
Epoch 2700, last 100 epochs reward: 1.0
8 4 8 1
Epoch 2800, last 100 epochs reward: 1.0
6 5 6 1
Epoch 2900, last 100 epochs reward: 1.0
6 5 6 1
Observation to message mapping:
[0 3 9 1 6 8 5 7 4 2]
Message to action mapping:
[0 3 9 1 8 6 4 7 5 2]
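The two mappings above should be inverse permutations: routing a state through the sender's table and then the receiver's recovers the state itself. A quick check against the printed arrays (an added verification, not part of the original notebook):

obs_to_msg = sender.message_weights.argmax(1)
msg_to_act = receiver.action_weights.argmax(1)
print(msg_to_act[obs_to_msg])  # [0 1 2 3 4 5 6 7 8 9]: a perfect protocol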
In [27]:
def make_gif(filename_base):
    # Stitch the snapshots saved every 25 epochs into an animated GIF.
    images = []
    for filename in [f'{filename_base}_{i}.png' for i in range(0, 3000, 25)]:
        images.append(imageio.imread(filename))
    imageio.mimsave(f'{filename_base}.gif', images)

make_gif('sender')
make_gif('receiver')
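On recent imageio releases (2.16 and later, assuming that is what is installed), the top-level imread emits a deprecation warning; the v2 namespace keeps the behaviour used here:

import imageio.v2 as iio

def make_gif(filename_base):
    # Same stitching as above, via the explicit v2 API.
    images = [iio.imread(f'{filename_base}_{i}.png') for i in range(0, 3000, 25)]
    iio.mimsave(f'{filename_base}.gif', images)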