Skip to content

Code for the paper "Closing the Curious Case of Neural Text Degeneration"

Notifications You must be signed in to change notification settings

mattf1n/basis-aware-threshold

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

13 Commits
 
 
 
 
 
 
 
 
 
 

Repository files navigation

basis-aware-threshold

Code for the paper Closing the Curious Case of Neural Text Degeneration

Setup

Clone and navigate to this repository, then

pip install .
pip install -r requirements.txt

BAT sampling implementation

BAT itself is fairly simple to implement.

import numpy as np, cvxpy as cp

def sample(
        embed: np.array,  # Model embedding matrix (or SVD approximation)
        threshold: float,  # Truncation threshold
        probs: np.array,  # Model output probabilities
    ) -> int:
    for token_id in np.random.choice(len(probs), size=len(probs), p=probs, replace=False)
        if probs[token_id] >= threshold: # Program will be infeasible, no need to run it
            return token_id

        # Construct and attempt to solve the linear program
        exp_delta = 1 / max(1 - threshold, 1e-10)
        p = cp.Variable(probs.shape)
        constraints = [
            embed.T @ p == embed.T @ probs,
            cp.sum(p) == 1,
            p[token_id] == 0,
            p <= probs * exp_delta,
        ]
        problem = cp.Problem(cp.Minimize(0), constraints)
        try:
            problem.solve()
            if problem.status == "infeasible":
                return token_id
        except cp.error.SolverError: # Numerical instability suggests infeasible
            return token_id 
    return np.argmax(probs) # Fall back to greedy

Our implementation can be found in src/parallel.py

Experiments

Parameter matching

To find the BAT parameter that rejects human tokens at the same rate as a threshold sampling parameter, compute the per-token rejection thresholds with scripts/param_matching.py, and analyze the results with scripts/relative_conserve.py (which prints the parameters for each method that reject 25% of human tokens).

MAUVE

Use scripts/generate.py to generate completions to prefixes from data/webtext/. Once the texts are generated, use scripts/compute_mauve.py to compute the MAUVE scores.

Figures

scripts/viz/fig1.py Generate data with scripts/unit_test.py, then generate figures with scripts/viz/fig1.py and scripts/unit_test.py.

scripts/viz/hrr.py scripts/viz/hrr.py.

scripts/viz/relcon_by_cons.py scripts/viz/relcon_by_cons.py.

scripts/viz/tern.py scripts/viz/tern.py.

Cite

@misc{finlayson2023closing,
      title={Closing the Curious Case of Neural Text Degeneration}, 
      author={Matthew Finlayson and John Hewitt and Alexander Koller and Swabha Swayamdipta and Ashish Sabharwal},
      year={2023},
      eprint={2310.01693},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

Contact

[email protected]

About

Code for the paper "Closing the Curious Case of Neural Text Degeneration"

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages