Skip to content

Commit

Permalink
Merge pull request #97 from honglu2875/honglu/dev
Browse files Browse the repository at this point in the history
Allow YaRN finetuning; Allow reading safetensors
  • Loading branch information
honglu2875 committed Feb 4, 2024
2 parents 4ed0a5b + 2233495 commit 571a797
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 18 deletions.
7 changes: 4 additions & 3 deletions aria/data/midi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utils for data/MIDI processing."""

import hashlib
import json
import re
Expand Down Expand Up @@ -319,9 +320,9 @@ def _extract_track_data(track: mido.MidiTrack):
if len(notes_to_close) > 0 and len(notes_to_keep) > 0:
# Note-on on the same tick but we already closed
# some previous notes -> it will continue, keep it.
last_note_on[
(message.note, message.channel)
] = notes_to_keep
last_note_on[(message.note, message.channel)] = (
notes_to_keep
)
else:
# Remove the last note on for this instrument
del last_note_on[(message.note, message.channel)]
Expand Down
1 change: 1 addition & 0 deletions aria/model/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Includes (PyTorch) transformer model and config classes."""

from dataclasses import dataclass
from typing import Optional, Union

Expand Down
4 changes: 2 additions & 2 deletions aria/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def sample(args):
from aria.tokenizer import RelTokenizer, AbsTokenizer
from aria.sample import greedy_sample
from aria.data.midi import MidiDict
from aria.utils import midi_to_audio
from aria.utils import midi_to_audio, _load_weight

if not cuda_is_available():
print("CUDA device is not available. Using CPU instead.")
Expand All @@ -150,7 +150,7 @@ def sample(args):
)

ckpt_path = _get_ckpt_path(args.c) # let user input path if not provided
model_state = torch.load(ckpt_path, map_location=device)
model_state = _load_weight(ckpt_path, device=device.type)
model_name = _get_model_name(
args.m, model_state
) # infer model name if not provided
Expand Down
1 change: 1 addition & 0 deletions aria/sample.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Contains generation/sampling code"""

# This file contains code from https://github.com/facebookresearch/llama which
# is available under the following license:

Expand Down
35 changes: 22 additions & 13 deletions aria/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
PretrainingDataset,
FinetuningDataset,
)
from aria.utils import _load_weight


# ----- USAGE -----
Expand Down Expand Up @@ -669,12 +670,16 @@ def resume_train(
else:
raise Exception

assert (
train_dataloader.dataset.max_seq_len == model_config.max_seq_len
), "max_seq_len differs between datasets and model config"
assert (
val_dataloader.dataset.max_seq_len == model_config.max_seq_len
), "max_seq_len differs between datasets and model config"
if (
model_config.yarn_config is None
or model_config.yarn_config.scale <= 1.0
):
assert (
train_dataloader.dataset.max_seq_len == model_config.max_seq_len
), "max_seq_len differs between datasets and model config"
assert (
val_dataloader.dataset.max_seq_len == model_config.max_seq_len
), "max_seq_len differs between datasets and model config"

(
model,
Expand Down Expand Up @@ -781,7 +786,7 @@ def train(
logger.info(f"Loaded model with config: {load_model_config(model_name)}")
if mode == "finetune":
try:
model.load_state_dict(torch.load(finetune_cp_path))
model.load_state_dict(_load_weight(finetune_cp_path))
except Exception as e:
raise Exception(
f"Failed to load checkpoint: {e}\n"
Expand Down Expand Up @@ -823,12 +828,16 @@ def train(
else:
raise Exception

assert (
train_dataloader.dataset.max_seq_len == model_config.max_seq_len
), "max_seq_len differs between datasets and model config"
assert (
val_dataloader.dataset.max_seq_len == model_config.max_seq_len
), "max_seq_len differs between datasets and model config"
if (
model_config.yarn_config is None
or model_config.yarn_config.scale <= 1.0
):
assert (
train_dataloader.dataset.max_seq_len == model_config.max_seq_len
), "max_seq_len differs between datasets and model config"
assert (
val_dataloader.dataset.max_seq_len == model_config.max_seq_len
), "max_seq_len differs between datasets and model config"

(
model,
Expand Down
15 changes: 15 additions & 0 deletions aria/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,18 @@ def midi_to_audio(mid_path: str, soundfont_path: str | None = None):
print(e)

print(f"Saved files: \n{wav_path}\n{mp3_path}")


def _load_weight(ckpt_path: str, device="cpu"):
if ckpt_path.endswith("safetensors"):
try:
from safetensors.torch import load_file
except ImportError as e:
raise ImportError(
f"Please install safetensors in order to read from the checkpoint: {ckpt_path}"
) from e
return load_file(ckpt_path, device=device)
else:
import torch

return torch.load(ckpt_path, map_location=device)

0 comments on commit 571a797

Please sign in to comment.