Skip to content

Commit

Permalink
Merge branch 'main' into check-streamable
Browse files Browse the repository at this point in the history
  • Loading branch information
norabelrose committed Apr 6, 2023
2 parents faa8496 + 7f0aaf1 commit 5259b26
Show file tree
Hide file tree
Showing 45 changed files with 295 additions and 119 deletions.
51 changes: 51 additions & 0 deletions .devcontainer/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Use Nvidia Ubuntu 20 base (includes CUDA if a supported GPU is present)
# https://hub.docker.com/r/nvidia/cuda
FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04@sha256:55211df43bf393d3393559d5ab53283d4ebc3943d802b04546a24f3345825bd9

ARG USERNAME
ARG USER_UID=1000
ARG USER_GID=$USER_UID

# Create the user
# https://code.visualstudio.com/remote/advancedcontainers/add-nonroot-user
RUN groupadd --gid $USER_GID $USERNAME \
&& useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \
&& usermod -a -G video user \
&& apt-get update \
&& apt-get install -y sudo \
&& echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
&& chmod 0440 /etc/sudoers.d/$USERNAME

ENV DEBIAN_FRONTEND=noninteractive

# Install dependencies
RUN apt-get update && \
apt-get -qq -y install \
software-properties-common && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get update && \
apt-get -qq -y install \
build-essential \
python3.10 \
python3.10-dev \
python3.10-distutils \
python3.10-venv \
curl \
git \
tmux

# Update package list, add the deadsnakes PPA, and install dependencies


ENV PATH="$HOME/.local/bin:$PATH"

# Install pip (we need the latest version not the standard Ubuntu version, to
# support modern wheels)
RUN sudo curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3.10 get-pip.py

# Set python aliases
RUN sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1
RUN sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1

# User the new user
USER $USERNAME
47 changes: 47 additions & 0 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
{
"name": "Python 3",
"build": {
"dockerfile": "Dockerfile",
"args": {
"USERNAME": "user"
}
},
"customizations": {
"vscode": {
"settings": {
"python.formatting.autopep8Path": "autopep8",
"python.linting.mypyPath": "mypy",
"terminal.integrated.defaultProfile.linux": "tmux",
"terminal.integrated.profiles.linux": {
"bash": {
"path": "bash",
"icon": "terminal-bash"
},
"tmux": {
"path": "bash",
"args": ["-c", "tmux new -ADs ${PWD##*/}"],
"icon": "terminal-tmux"
}
}
},
"extensions": [
"davidanson.vscode-markdownlint",
"donjayamanne.githistory",
"donjayamanne.python-extension-pack",
"github.vscode-pull-request-github",
"ms-python.python",
"ms-toolsai.jupyter",
"ms-vsliveshare.vsliveshare-pack",
"njpwerner.autodocstring",
"stkb.rewrap",
"streetsidesoftware.code-spell-checker",
"tushortz.python-extended-snippets",
"yzhang.markdown-all-in-one",
"elagil.pre-commit-helper",
"eamodio.gitlens"
]
}
},
"containerUser": "user",
"postCreateCommand": "pip install -e .[dev] --no-warn-script-location"
}
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ repos:
args: ["--unsafe"]
- id: check-added-large-files
- repo: https://github.com/psf/black
rev: 23.1.0
rev: 23.3.0
hooks:
- id: black
- repo: https://github.com/pycqa/flake8
rev: '6.0.0'
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.0.257'
hooks:
- id: flake8
args: ["--ignore=E203,F401,W503", --max-line-length=88]
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/codespell-project/codespell
rev: v2.2.4
hooks:
Expand Down
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,16 @@ The hidden states resulting from `elk elicit` are cached as a HuggingFace datase
## Development
Use `pip install pre-commit && pre-commit install` in the root folder before your first commit.

### Devcontainer

[
![Open in Remote - Containers](
https://img.shields.io/static/v1?label=Remote%20-%20Containers&message=Open&color=blue&logo=visualstudiocode
)
](
https://vscode.dev/redirect?url=vscode:https://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/EleutherAI/elk
)

### Run tests
```bash
pytest
Expand All @@ -49,6 +59,13 @@ We use [pyright](https://github.com/microsoft/pyright), which is built into the
pyright
```

### Run the linter
We use [ruff](https://beta.ruff.rs/docs/). It is installed as a pre-commit hook, so you don't have to run it manually.
If you want to run it manually, you can do so with:
```bash
ruff . --fix
```

### Contributing to this repository

If you work on a new feature / fix or some other code task, make sure to create an issue and assign it to yourself (Maybe, even share it in the elk channel of Eleuther's Discord with a small note). In this way, others know you are working on the issue and people won't do the same thing twice 👍 Also others can contact you easily.
4 changes: 3 additions & 1 deletion elk/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .extraction import extract_hiddens, Extract
from .extraction import Extract, extract_hiddens

__all__ = ["extract_hiddens", "Extract"]
1 change: 0 additions & 1 deletion elk/__main__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Main entry point for `elk`."""

from dataclasses import dataclass
from pathlib import Path
from typing import Union

from simple_parsing import ArgumentParser
Expand Down
5 changes: 3 additions & 2 deletions elk/calibration.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import warnings
from dataclasses import dataclass, field
from torch import Tensor
from typing import NamedTuple

import torch
import warnings
from torch import Tensor


class CalibrationEstimate(NamedTuple):
Expand Down
3 changes: 2 additions & 1 deletion elk/eigsh.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from torch import Tensor
from typing import Literal, Optional

import torch
import torch.nn.functional as F
from torch import Tensor


def lanczos_eigsh(
Expand Down
14 changes: 2 additions & 12 deletions elk/evaluation/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,18 @@
import csv
import os
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Callable, Literal, Optional, cast
from typing import Callable, Literal, Optional

import torch
import torch.multiprocessing as mp
from simple_parsing.helpers import Serializable, field
from torch import Tensor
from tqdm.auto import tqdm

from datasets import DatasetDict
from elk.evaluation.evaluate_log import EvalLog
from elk.extraction.extraction import Extract
from elk.run import Run
from elk.training import Reporter

from ..files import elk_reporter_dir, memorably_named_dir
from ..training.preprocessing import normalize
from ..files import elk_reporter_dir
from ..utils import (
assert_type,
int16_to_float32,
select_train_val_splits,
select_usable_devices,
)

Expand Down
16 changes: 14 additions & 2 deletions elk/extraction/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
from .balanced_sampler import BalancedSampler, FewShotSampler
from .extraction import Extract, extract_hiddens, extract
from .generator import _GeneratorConfig, _GeneratorBuilder
from .extraction import Extract, extract, extract_hiddens
from .generator import _GeneratorBuilder, _GeneratorConfig
from .prompt_loading import PromptConfig, yield_prompts

__all__ = [
"_GeneratorBuilder",
"_GeneratorConfig",
"BalancedSampler",
"extract_hiddens",
"extract",
"Extract",
"FewShotSampler",
"PromptConfig",
"yield_prompts",
]
13 changes: 7 additions & 6 deletions elk/extraction/balanced_sampler.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from ..math_util import stochastic_round_constrained
from ..utils import infer_label_column
from ..utils.typing import assert_type
from collections import deque
from dataclasses import dataclass
from datasets import IterableDataset, Features
from itertools import cycle
from random import Random
from typing import Iterable, Iterator, Optional

from datasets import Features, IterableDataset
from torch.utils.data import IterableDataset as TorchIterableDataset
from typing import Iterator, Optional, Iterable

from ..math_util import stochastic_round_constrained
from ..utils import infer_label_column
from ..utils.typing import assert_type


class BalancedSampler(TorchIterableDataset):
Expand Down
7 changes: 3 additions & 4 deletions elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
from typing import Iterable, Literal, Optional, Union

import torch
from simple_parsing import Serializable, field
from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedModel

from datasets import (
Array3D,
ClassLabel,
Expand All @@ -20,11 +17,13 @@
Value,
get_dataset_config_info,
)
from simple_parsing import Serializable, field
from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedModel

from elk.utils.typing import float32_to_int16

from ..utils import (
assert_type,
infer_label_column,
select_train_val_splits,
select_usable_devices,
)
Expand Down
2 changes: 1 addition & 1 deletion elk/extraction/generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Callable, Optional, Any, Dict
from typing import Any, Callable, Dict, Optional

import datasets
from datasets.splits import NamedSplit
Expand Down
Empty file.
26 changes: 12 additions & 14 deletions elk/extraction/prompt_loading.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
from dataclasses import dataclass
from random import Random
from typing import Any, Iterator, Literal, Optional

from datasets import (
Dataset,
Features,
load_dataset,
)
from datasets.distributed import split_dataset_by_node
from simple_parsing.helpers import Serializable, field

from ..promptsource import DatasetTemplates
from ..utils import (
assert_type,
Expand All @@ -8,20 +20,6 @@
select_train_val_splits,
)
from .balanced_sampler import FewShotSampler
from dataclasses import dataclass
from datasets import (
interleave_datasets,
load_dataset,
ClassLabel,
Dataset,
Features,
IterableDataset,
Sequence,
)
from datasets.distributed import split_dataset_by_node
from random import Random
from simple_parsing.helpers import field, Serializable
from typing import Any, Iterator, Literal, Optional


@dataclass
Expand Down
4 changes: 2 additions & 2 deletions elk/files.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Helper functions for dealing with files."""

from pathlib import Path
import json
import os
import random
from pathlib import Path
from typing import Optional

from simple_parsing import Serializable
import yaml
from simple_parsing import Serializable


def elk_reporter_dir() -> Path:
Expand Down
1 change: 1 addition & 0 deletions elk/logging.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging

from .utils import select_train_val_splits


Expand Down
3 changes: 2 additions & 1 deletion elk/math_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from torch import Tensor
import math
import random

import torch
from torch import Tensor


@torch.jit.script
Expand Down
1 change: 1 addition & 0 deletions elk/parsing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re

from .training.losses import LOSSES


Expand Down
2 changes: 2 additions & 0 deletions elk/promptsource/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .templates import DatasetTemplates, Template

__all__ = ["DatasetTemplates", "Template"]
12 changes: 6 additions & 6 deletions elk/promptsource/templates.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from collections import Counter, defaultdict
from jinja2 import BaseLoader, Environment, meta
from pathlib import Path
from shutil import rmtree
from typing import Optional
import logging
import os
import random
import uuid
import yaml
from collections import Counter, defaultdict
from pathlib import Path
from shutil import rmtree
from typing import Optional

import yaml
from jinja2 import BaseLoader, Environment, meta

# Truncation of jinja template variables
# 1710 = 300 words x 4.7 avg characters per word + 300 spaces
Expand Down
Loading

0 comments on commit 5259b26

Please sign in to comment.