Skip to content

Commit

Permalink
feat: support mixed-precision in inference pass (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
PABannier committed Oct 9, 2023
1 parent 9c3316c commit c150c8e
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 51 deletions.
46 changes: 41 additions & 5 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,12 @@
parser = argparse.ArgumentParser()
parser.add_argument("--dir-model", type=str, required=True)
parser.add_argument("--out-dir", type=str, required=True)
parser.add_argument("--use-f16", type=bool, default=True)
parser.add_argument("--use-f16", action="store_true")


def parse_codec_model(checkpoint, out_dir, use_f16):
def parse_codec_model(checkpoint, outfile, use_f16):
"""Load encodec model checkpoint."""
outfile = open(out_dir, "wb")
outfile.write(struct.pack("i", 0x67676d6c)) # ggml magic
n_f16, n_f32 = 0, 0

for name in checkpoint.keys():
if "weight_g" in name:
Expand Down Expand Up @@ -81,18 +80,21 @@ def parse_codec_model(checkpoint, out_dir, use_f16):
print(f"Processing variable: {name} with shape: {var_data.shape}")

if use_f16:
if "weight" in name:
if "weight" in name or "embed" in name:
print(" Converting to float16")
var_data = var_data.astype(np.float16)
ftype_cur = 1
n_f16 += 1
else:
print(" Converting to float32")
var_data = var_data.astype(np.float32)
ftype_cur = 0
n_f32 += 1
else:
print(" Converting to float32")
var_data = var_data.astype(np.float32)
ftype_cur = 0
n_f32 += 1

n_dims = len(var_data.shape)
encoded_name = name.encode("utf-8")
Expand All @@ -106,6 +108,31 @@ def parse_codec_model(checkpoint, out_dir, use_f16):

outfile.close()

print("\n")
print(f"n_f16: {n_f16} ({n_f16/(n_f16 + n_f32)*100:.0f}%)")
print(f"n_f32: {n_f32} ({n_f32/(n_f16 + n_f32)*100:.0f}%)")


def parse_hparams(outfile, use_f16):
    """Write the hyperparameter header to ``outfile``.

    Eight values are emitted in a fixed order, each one packed on its own
    with the ``"i"`` struct format: in_channels, hidden_dim, n_filters,
    kernel_size, residual_kernel_size, n_q, n_bins and ftype.

    Args:
        outfile: Writable binary file-like object; only ``write`` is called.
        use_f16: Flag converted with ``int()`` and stored as the ``ftype``
            field.
    """
    # Hardcoded for now: only the 24 kHz model is supported.
    header_fields = (
        1,             # in_channels
        128,           # hidden_dim
        32,            # n_filters
        7,             # kernel_size
        3,             # residual_kernel_size
        32,            # n_q
        1024,          # n_bins
        int(use_f16),  # ftype
    )

    for field in header_fields:
        outfile.write(struct.pack("i", field))


if __name__ == "__main__":
args = parser.parse_args()
Expand All @@ -118,6 +145,15 @@ def parse_codec_model(checkpoint, out_dir, use_f16):
outfile = Path(out_dir / "ggml-model.bin")

checkpoint = torch.load(dir_model / "encodec_24khz-d7cc33bc.th", map_location="cpu")

# Step 1: insert ggml magic
outfile = open(outfile, "wb")
outfile.write(struct.pack("i", 0x67676d6c))

# Step 2: insert hyperparameters
parse_hparams(outfile, args.use_f16)

# Step 3: insert weights
parse_codec_model(checkpoint, outfile, args.use_f16)

print("Done.")
Loading

0 comments on commit c150c8e

Please sign in to comment.