Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Embed yolo files #831

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Previous commit
Next commit
refactor code, fix copying key value, add --force
  • Loading branch information
katsu560 committed Jun 23, 2024
commit e18593c339ad9601111d4cb3be95f5a920de2e9f
73 changes: 38 additions & 35 deletions examples/yolo/gguf-addfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,47 +73,41 @@ def get_field_data(reader: GGUFReader, key: str) -> Any:
return decode_field(field)


def copy_with_filename(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: Mapping[str, str], filename: str[Any]) -> None:
def copy_with_filename(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, filename: str[Any]) -> None:
logger.debug(f'copy_with_filename: {filename}') #debug
val = filename
for field in reader.fields.values():
# Suppress virtual fields and fields written by GGUFWriter
if field.name == Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
logger.debug(f'Suppressing {field.name}')
continue

# Skip old chat templates if we have new ones
if field.name.startswith(Keys.Tokenizer.CHAT_TEMPLATE) and Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
logger.debug(f'Skipping {field.name}')
continue

old_val = decode_field(field)
val = new_metadata.get(field.name, old_val)

if field.name in new_metadata:
logger.debug(f'Modifying {field.name}: "{old_val}" -> "{val}"')
del new_metadata[field.name]
elif val is not None:
logger.debug(f'Copying {field.name}')

if val is not None:
# Copy existed fields except 'embedded_files'
if not field.name == Keys.EMBEDDED_FILES:
cur_val = decode_field(field)
writer.add_key(field.name)
writer.add_val(val, field.types[0])
writer.add_val(cur_val, field.types[0])
logger.debug(f'Copying {field.name}')
continue

if Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
logger.debug('Adding chat template(s)')
writer.add_chat_template(new_metadata[Keys.Tokenizer.CHAT_TEMPLATE])
del new_metadata[Keys.Tokenizer.CHAT_TEMPLATE]
# Update embedded_files
val = decode_field(field)
for path in filename:
logger.debug(f'Adding {field.name}: {path}')
val.append(path)

# add filenames to kv
writer.add_array(Keys.EMBEDDED_FILES, filename)
# Add filenames to kv
logger.info(f'* Modifying {Keys.EMBEDDED_FILES} to {val}')
writer.add_array(Keys.EMBEDDED_FILES, val)

for tensor in reader.tensors:
# Dimensions are written in reverse order, so flip them first
shape = np.flipud(tensor.shape)
writer.add_tensor_info(tensor.name, shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)

# add file info as tensor_info
# Add file info as tensor_info
for path in filename:
logger.debug(f'Adding {path}')
logger.debug(f'Adding tensor_info {path}')
with open(path, "rb") as f:
data = f.read()
data_len = len(data)
Expand All @@ -128,9 +122,9 @@ def copy_with_filename(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_met
for tensor in reader.tensors:
writer.write_tensor_data(tensor.data)

# write file body as tensor data
# Write file body as tensor data
for path in filename:
logger.debug(f'Adding {path}')
logger.debug(f'Adding tensor data {path}')
with open(path, "rb") as f:
data = f.read()
data_len = len(data)
Expand All @@ -145,6 +139,7 @@ def main() -> None:
parser.add_argument("input", type=str, help="GGUF format model input filename")
parser.add_argument("output", type=str, help="GGUF format model output filename")
parser.add_argument("addfiles", type=str, nargs='+', help="add filenames ...")
parser.add_argument("--force", action="store_true", help="Bypass warnings without confirmation")
parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
Expand All @@ -154,6 +149,15 @@ def main() -> None:
arch = get_field_data(reader, Keys.General.ARCHITECTURE)
endianess = get_byteorder(reader)

if os.path.isfile(args.output) and not args.force:
logger.warning('*** Warning *** Warning *** Warning **')
logger.warning(f'* The "{args.output}" GGUF file already exists, it will be overwritten!')
logger.warning('* Enter exactly YES if you are positive you want to proceed:')
response = input('YES, I am sure> ')
if response != 'YES':
logger.info("You didn't enter YES. Okay then, see ya!")
sys.exit(0)

logger.info(f'* Writing: {args.output}')
writer = GGUFWriter(args.output, arch=arch, endianess=endianess)

Expand All @@ -162,14 +166,13 @@ def main() -> None:
logger.debug(f'Setting custom alignment: {alignment}')
writer.data_alignment = alignment

logger.info(f'* Adding: {args.addfiles}')
new_metadata = {}
filename = []
for path in args.addfiles:
filename.append(path)
logger.info(f'* Adding: {path}')
copy_with_filename(reader, writer, new_metadata, filename)

if args.addfiles is not None:
filename = []
for path in args.addfiles:
filename.append(path)
logger.info(f'* Adding: {path}')
copy_with_filename(reader, writer, filename)


if __name__ == '__main__':
main()
Loading