From e0dd12fe3527a185cf7a90a0ae3b3b11cfe4df29 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Mon, 17 Jun 2024 14:05:07 +0800 Subject: [PATCH] fix dispatch bugs --- xtuner/model/modules/dispatch/__init__.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/xtuner/model/modules/dispatch/__init__.py b/xtuner/model/modules/dispatch/__init__.py index e69949356..0ff54941c 100644 --- a/xtuner/model/modules/dispatch/__init__.py +++ b/xtuner/model/modules/dispatch/__init__.py @@ -224,6 +224,10 @@ def replace_rote(model): from mmengine import print_log print_log = log_once(print_log) + assert hasattr(model.config, 'rope_theta'), \ + '`rope_theta` should be in the model config.' + rope_theta = model.config.rope_theta + def traverse(module): for name, child in module.named_children(): cls_name = type(child).__name__ @@ -232,8 +236,10 @@ def traverse(module): rote = rote.build() print_log(f'replace {cls_name}', 'current') dim_model = child.inv_freq.shape[0] * 2 - child_new = rote(dim_model, child.max_seq_len_cached).to( - device=child.inv_freq.device, dtype=child.inv_freq.dtype) + child_new = rote(dim_model, child.max_seq_len_cached, + rope_theta).to( + device=child.inv_freq.device, + dtype=child.inv_freq.dtype) setattr(module, name, child_new) else: traverse(child)