Add compatibility test and fix some problems with peft>=0.6.0 (#146)
tastelikefeet committed Nov 13, 2023
1 parent b73acb8 commit 35487b0
Showing 9 changed files with 70 additions and 65 deletions.
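Most of the peft>=0.6.0 compatibility work lands in `swift/tuners/lora.py` below: the version check against `0.6.0` suggests that newer peft releases construct their LoRA layers from the wrapped base layer rather than from `in_features`/`out_features`, and the `self.active_adapter[0]` indexing in `forward` suggests the active adapter is list-valued there. The sketch below only illustrates that constructor dispatch under those assumptions; `build_compat_lora_layer` and `wrapper_cls` are hypothetical names introduced for illustration and are not part of the commit.

```python
# Illustrative sketch only, not part of the commit.
# Assumption: peft>=0.6.0 LoRA layers accept the (wrapped) base layer, while older
# releases accept the base layer's dimensions, as the diff below suggests.
from packaging import version

import peft


def build_compat_lora_layer(layer_cls, adapter_name, base_layer, wrapper_cls,
                            r=0, lora_alpha=1, lora_dropout=0.0, **kwargs):
    """Hypothetical helper: build a swift LoRA layer against either peft API."""
    if version.parse(peft.__version__) >= version.parse('0.6.0'):
        # new-style signature: pass the wrapped base layer itself
        return layer_cls(adapter_name, wrapper_cls(base_layer),
                         r, lora_alpha, lora_dropout, **kwargs)
    # old-style signature: pass the base layer's dimensions
    return layer_cls(adapter_name, base_layer.in_features, base_layer.out_features,
                     r, lora_alpha, lora_dropout, **kwargs)
```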
3 changes: 2 additions & 1 deletion .dev_scripts/ci_container_test.sh
@@ -20,10 +20,11 @@ if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
fi
fi

pip install -r requirements/framework.txt -i https://mirrors.aliyun.com/pypi/simple/
pip install -r requirements/framework.txt -U -i https://mirrors.aliyun.com/pypi/simple/

# test with install
pip install .
pip install auto_gptq -U -i https://mirrors.aliyun.com/pypi/simple/
else
echo "Running case in release image, run case directly!"
fi
2 changes: 1 addition & 1 deletion README.md
@@ -73,7 +73,7 @@ Quickly fine-tune, infer with LLM, and build a Web-UI.
```bash
git clone https://github.com/modelscope/swift.git
cd swift
pip install .[llm]
pip install .
```

#### Run using Python
2 changes: 1 addition & 1 deletion README_CN.md
@@ -70,7 +70,7 @@ SWIFT(Scalable lightWeight Infrastructure for Fine-Tuning)是一个可扩展
```bash
git clone https://github.com/modelscope/swift.git
cd swift
pip install .[llm]
pip install .
```

#### 使用python运行
2 changes: 1 addition & 1 deletion examples/pytorch/llm/README.md
@@ -58,7 +58,7 @@ Experimental environment: A10, 3090, V100, A100, ...
pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/
git clone https://github.com/modelscope/swift.git
cd swift
pip install .[llm]
pip install .
# The following script needs to be executed in this directory.
cd examples/pytorch/llm

2 changes: 1 addition & 1 deletion examples/pytorch/llm/README_CN.md
@@ -57,7 +57,7 @@
pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/
git clone https://github.com/modelscope/swift.git
cd swift
pip install .[llm]
pip install .
# 下面的脚本需要在此目录下执行
cd examples/pytorch/llm

7 changes: 7 additions & 0 deletions requirements/framework.txt
@@ -1,16 +1,23 @@
accelerate
charset_normalizer
cpm_kernels
datasets
diffusers>=0.18.0
gradio
jieba
matplotlib
modelscope>=1.9.3
nltk
numpy
optimum
pandas
peft>=0.5.0
requests
rouge
safetensors
sentencepiece
tensorboard
tiktoken
tqdm
transformers>=4.33
transformers_stream_generator
7 changes: 0 additions & 7 deletions requirements/llm.txt

This file was deleted.

19 changes: 1 addition & 18 deletions setup.py
@@ -117,26 +117,9 @@ def gen_packages_items():
return gen_packages_items()


def pack_resource():
# pack resource such as configs and tools
root_dir = 'package/'
if os.path.isdir(root_dir):
shutil.rmtree(root_dir)
os.makedirs(root_dir)

proj_dir = root_dir + 'swift/'
shutil.copytree('./swift', proj_dir)
shutil.copytree('./requirements', 'package/requirements')
shutil.copy('./requirements.txt', 'package/requirements.txt')
shutil.copy('./MANIFEST.in', 'package/MANIFEST.in')
shutil.copy('./README.md', 'package/README.md')


if __name__ == '__main__':
pack_resource()
os.chdir('package')
install_requires, deps_link = parse_requirements('requirements.txt')
extra_requires = {'llm': parse_requirements('requirements/llm.txt')}
extra_requires = {}
all_requires = []

setup(
91 changes: 56 additions & 35 deletions swift/tuners/lora.py
@@ -10,6 +10,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging import version
from peft.import_utils import (is_auto_gptq_available, is_bnb_4bit_available,
is_bnb_available)
from peft.utils import get_auto_gptq_quant_linear, get_quantization_config
@@ -18,6 +19,22 @@
from ..utils.torch_utils import find_sub_module
from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput


class LinearWrapper:

def __init__(self, module: torch.nn.Module):
self.module = module

def __getattr__(self, name):
return getattr(self.module, name)

def forward(self, *args, **kwargs):
return self.module.forward_origin(*args, **kwargs)

def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)


if is_bnb_available():
import bitsandbytes as bnb

@@ -28,16 +45,21 @@ class Linear8bitLt(ActivationMixin, _Linear8bitLt):
def __init__(
self,
adapter_name,
in_features,
out_features,
base_layer,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
**kwargs,
):
super(ActivationMixin,
self).__init__(adapter_name, in_features, out_features, r,
lora_alpha, lora_dropout, **kwargs)
if version.parse(peft.__version__) >= version.parse('0.6.0'):
super(ActivationMixin,
self).__init__(adapter_name, LinearWrapper(base_layer),
r, lora_alpha, lora_dropout, **kwargs)
else:
super(ActivationMixin,
self).__init__(adapter_name, base_layer.in_features,
base_layer.out_features, r, lora_alpha,
lora_dropout, **kwargs)
super(Linear8bitLt, self).__init__()

def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -47,23 +69,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


if is_bnb_4bit_available():
import peft
from peft.tuners.lora import Linear4bit as _Linear4bit

class Linear4bit(ActivationMixin, _Linear4bit):

def __init__(
self,
adapter_name,
in_features,
out_features,
base_layer,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
**kwargs,
):
super(ActivationMixin,
self).__init__(adapter_name, in_features, out_features, r,
lora_alpha, lora_dropout, **kwargs)
if version.parse(peft.__version__) >= version.parse('0.6.0'):
super(ActivationMixin,
self).__init__(adapter_name, LinearWrapper(base_layer),
r, lora_alpha, lora_dropout, **kwargs)
else:
super(ActivationMixin,
self).__init__(adapter_name, base_layer.in_features,
base_layer.out_features, r, lora_alpha,
lora_dropout, **kwargs)
super(Linear4bit, self).__init__()

def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -99,50 +127,45 @@ def __init__(
if not self.use_qa_lora else quant_linear_module.infeatures
// self.group_size,
out_features=quant_linear_module.outfeatures)
self.quant_linear_module = quant_linear_module
self.quant_linear_module = LinearWrapper(quant_linear_module)
self.weight = quant_linear_module.qweight
init_lora_weights = kwargs.pop('init_lora_weights', True)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout,
init_lora_weights)
self.active_adapter = adapter_name
self.set_adapter(adapter_name)
super(QuantLinear, self).__init__()
if self.use_qa_lora:
self.qa_pool = torch.nn.AvgPool1d(
self.group_size
) # using pooling layer to conduct sum operation

def call_quant_linear_module(*args, **kwargs):
return quant_linear_module.forward_origin(*args, **kwargs)

self.call_quant_linear_module = call_quant_linear_module
self.quant_linear_module = None

def forward(self, x: torch.Tensor):
result = self.call_quant_linear_module(x)
result = self.quant_linear_module(x)
if not self.is_activated(
) or self.disable_adapters or self.active_adapter not in self.lora_A.keys(
):
) or self.disable_adapters or self.active_adapter[
0] not in self.lora_A.keys():
return result
elif self.r[self.active_adapter] > 0:
elif self.r[self.active_adapter[0]] > 0:
result = result.clone()
if not torch.is_autocast_enabled():
expected_dtype = result.dtype
x = x.to(self.lora_A[self.active_adapter].weight.dtype)
x = x.to(self.lora_A[self.active_adapter[0]].weight.dtype)
if self.use_qa_lora:
x = self.qa_pool(x) * self.group_size
output = (
self.lora_B[self.active_adapter](
self.lora_A[self.active_adapter](self.lora_dropout[
self.active_adapter](x))).to(expected_dtype)
* self.scaling[self.active_adapter])
self.lora_B[self.active_adapter[0]](
self.lora_A[self.active_adapter[0]](
self.lora_dropout[self.active_adapter[0]]
(x))).to(expected_dtype)
* self.scaling[self.active_adapter[0]])
else:
if self.use_qa_lora:
x = self.qa_pool(x) * self.group_size
output = (
self.lora_B[self.active_adapter](
self.lora_A[self.active_adapter](
self.lora_dropout[self.active_adapter](x)))
* self.scaling[self.active_adapter])
self.lora_B[self.active_adapter[0]](
self.lora_A[self.active_adapter[0]](
self.lora_dropout[self.active_adapter[0]](x)))
* self.scaling[self.active_adapter[0]])
result += output
return result

@@ -302,8 +325,7 @@ def _dynamic_patch_lora(model: torch.nn.Module,
})
lora_module = Linear8bitLt(
'default',
sub_module.in_features,
sub_module.out_features,
sub_module,
bias=hasattr(sub_module, 'bias')
and sub_module.bias is not None,
**eight_bit_kwargs)
@@ -321,8 +343,7 @@ def _dynamic_patch_lora(model: torch.nn.Module,
})
lora_module = Linear4bit(
'default',
sub_module.in_features,
sub_module.out_features,
sub_module,
bias=hasattr(sub_module, 'bias')
and sub_module.bias is not None,
**four_bit_kwargs)
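
The `forward` hunks above switch from `self.active_adapter` to `self.active_adapter[0]` when looking up `lora_A`, `lora_B`, `lora_dropout`, and `scaling`, which suggests newer peft stores the active adapter as a list. A small, hypothetical normalizer like the one below (not part of the commit) would accept either representation:

```python
# Hypothetical helper, for illustration only; the commit itself indexes
# self.active_adapter[0] directly in QuantLinear.forward above.
def primary_adapter_name(active_adapter):
    """Return the first adapter name whether it is stored as a str or a list/tuple."""
    if isinstance(active_adapter, (list, tuple)):
        return active_adapter[0]
    return active_adapter


# usage sketch
assert primary_adapter_name('default') == 'default'
assert primary_adapter_name(['default']) == 'default'
```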
