Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix model conversion #76

Merged
merged 2 commits into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,12 @@ In order to build bark.cpp you have two different options. We recommend using `C
### Prepare data & Run

```bash
# obtain the original bark and encodec weights and place them in ./models
python3 download_weights.py --download-dir ./models

# install Python dependencies
python3 -m pip install -r requirements.txt

# obtain the original bark and encodec weights and place them in ./models
python3 download_weights.py --download-dir ./models

# convert the model to ggml format
python3 convert.py \
--dir-model ./models \
Expand Down
6 changes: 3 additions & 3 deletions download_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch


ENCODEC_PATH = Path("https://dl.fbaipublicfiles.com/encodec/v0/encodec_24khz-d7cc33bc.th")
ENCODEC_PATH = "https://dl.fbaipublicfiles.com/encodec/v0/encodec_24khz-d7cc33bc.th"

REMOTE_MODEL_PATHS = {
"text": {
Expand Down Expand Up @@ -39,11 +39,11 @@

print(" ### Downloading EnCodec weights...")
state_dict = torch.hub.load_state_dict_from_url(
str(ENCODEC_PATH),
ENCODEC_PATH,
map_location="cpu",
check_hash=True
)
with open(out_dir / ENCODEC_PATH.name, "wb") as fout:
with open(out_dir / Path(ENCODEC_PATH).name, "wb") as fout:
torch.save(state_dict, fout)

print("Done.")