Skip to content

bkj/pbt

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

3 Commits
 
 
 
 
 
 
 
 
 
 

Repository files navigation

pbt: Population Based Training

Code to replicate figure 2 of Population Based Training of Neural Networks, Jaderberg et al

from __future__ import print_function

import numpy as np

import torch
from torch.optim import SGD
from torch.autograd import Variable
from torch.nn import Parameter

from matplotlib import pyplot as plt
plt.rcParams['figure.figsize'] = (10, 10)


_ = np.random.seed(123)
_ = torch.manual_seed(123)
# Define PBT worker

class Worker():
    
    def __init__(self, theta, h, objective, surrogate_objective, id):
        
        self.theta = theta
        self.h = h
        self.objective = objective
        self.surrogate_objective = surrogate_objective
        self.id = id
        
        self._opt = SGD([theta], lr=0.01)
        self._history = {"theta" : [], "h" : [], "score" : []}
    
    @property
    def history(self):
        return {
            "theta" : np.vstack(self._history['theta']),
            "h"     : np.vstack(self._history['h']),
            "score" : np.array(self._history['score']),
        }
    
    def _log(self):
        self._history['theta'].append(self.theta.data.numpy().copy())
        self._history['h'].append(self.h.data.numpy().copy())
        self._history['score'].append(self.eval())
    
    
    def step(self):
        """ Take an optimization step, given current hyperparemeters and surrogate objective """
        self._log()
        
        self._opt.zero_grad()
        surrogate_loss = -1 * self.surrogate_objective(self.theta, self.h)
        surrogate_loss.backward()
        self._opt.step()
        
    def eval(self):
        """ Evalute actual objective -- eg measure accuracy on the hold-out set """
        
        return self.objective(self.theta).data[0]
    
    def exploit(self, population):
        """ Copy theta from best member of the population """
        
        current_scores = [{
            "id": worker.id,
            "score": worker.eval()
        } for worker in population]
        
        best_worker = sorted(current_scores, key=lambda x: x['score'])[-1]
        
        if best_worker['id'] != self.id:
            self.theta.data.set_(population[best_worker['id']].theta.data.clone())
    
    def explore(self, sd=0.1):
        """ Add normal noise to hyperparameter vector """
        
        self.h.add_(Variable(torch.randn(2) * sd))
def run_experiment(do_explore=False, do_exploit=False, interval=5, n_steps=200):
    
    # Define objective functions
    objective = lambda theta: 1.2 - (theta ** 2).sum()
    surrogate_objective = lambda theta, h: 1.2 - ((h * theta) ** 2).sum()
    
    # Create population
    population = [
        Worker(
            theta=Parameter(torch.FloatTensor([0.9, 0.9])),
            h=Variable(torch.FloatTensor([1.0, 0.0])),
            objective=objective,
            surrogate_objective=surrogate_objective,
            id=0,
        ),
        Worker(
            theta=Parameter(torch.FloatTensor([0.9, 0.9])),
            h=Variable(torch.FloatTensor([0.0, 1.0])),
            objective=objective,
            surrogate_objective=surrogate_objective,
            id=1,
        ),
    ]
    
    # Train
    for step in range(n_steps):
        for worker in population:
            if not (step + 1) % interval:
                
                if do_exploit:
                    worker.exploit(population)
                    
                if do_explore:
                    worker.explore()
            
            worker.step()
    
    return population
# Run experiments w/ various PBT settings
pbt = run_experiment(do_explore=True, do_exploit=True) # Explore and exploit
explore = run_experiment(do_explore=True, do_exploit=False) # Explore only
exploit = run_experiment(do_explore=False, do_exploit=True) # Exploit only
grid = run_experiment(do_explore=False, do_exploit=False) # Independent training runs -- eg, regular grid search
def plot_score(ax, workers, run_name):
    """ Plot performance """
    for worker in workers:
        history = worker.history
        _ = ax.plot(history['score'], label="%s worker %d" % (run_name, worker.id), alpha=0.5)
    
    _ = ax.set_title(run_name)
    _ = ax.set_ylim(-1, 1.3)
    _ = ax.axhline(1.2, c='lightgrey')

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2)
plot_score(ax1, pbt, 'pbt')
plot_score(ax2, explore, 'explore')
plot_score(ax3, exploit, 'exploit')
plot_score(ax4, grid, 'grid')
_ = plt.tight_layout(pad=1)
plt.show()

png

def plot_theta(ax, workers, run_name):
    """ Plot values of theta """
    for worker in workers:
        history = worker.history
        _ = ax.scatter(history['theta'][:,0], history['theta'][:,1], 
            s=2, alpha=0.5, label="%s worker %d" % (run_name, worker.id))
    
    _ = ax.set_title(run_name)
    _ = ax.set_xlim(0, 1)
    _ = ax.set_ylim(0, 1)

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2)
plot_theta(ax1, pbt, 'pbt')
plot_theta(ax2, explore, 'explore')
plot_theta(ax3, exploit, 'exploit')
plot_theta(ax4, grid, 'grid')
_ = plt.tight_layout(pad=1)
plt.show()

png