-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
164 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
{ | ||
"nbformat": 4, | ||
"nbformat_minor": 0, | ||
"metadata": { | ||
"colab": { | ||
"provenance": [], | ||
"gpuType": "T4" | ||
}, | ||
"kernelspec": { | ||
"name": "python3", | ||
"display_name": "Python 3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
}, | ||
"accelerator": "GPU" | ||
}, | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"## NOTE\n", | ||
"1. To run this notebook you need to have colab pro or plus as the free version will run out of memory.\n", | ||
"2. An alternative you can run this notebook locally if u have 16+ gb RAM and good GPU\n", | ||
"3. [More illustrated example](https://colab.research.google.com/github/nengelmann/Fuyu-8B---Exploration/blob/main/Fuyu_8B_Exploration.ipynb#scrollTo=SRz1AjrWPZsk)" | ||
], | ||
"metadata": { | ||
"id": "rSZK0YLhX3Nc" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "0L1N7cRvobWX", | ||
"outputId": "3cdcf2f6-ce45-4e0e-a5bd-dde653894ec8" | ||
}, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"Collecting transformers\n", | ||
" Downloading transformers-4.35.2-py3-none-any.whl (7.9 MB)\n", | ||
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.9/7.9 MB\u001b[0m \u001b[31m55.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | ||
"\u001b[?25hCollecting accelerate\n", | ||
" Downloading accelerate-0.24.1-py3-none-any.whl (261 kB)\n", | ||
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m261.4/261.4 kB\u001b[0m \u001b[31m33.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | ||
"\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.13.1)\n", | ||
"Collecting huggingface-hub<1.0,>=0.16.4 (from transformers)\n", | ||
" Downloading huggingface_hub-0.19.4-py3-none-any.whl (311 kB)\n", | ||
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m311.7/311.7 kB\u001b[0m \u001b[31m38.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | ||
"\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.23.5)\n", | ||
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.2)\n", | ||
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n", | ||
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.6.3)\n", | ||
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n", | ||
"Collecting tokenizers<0.19,>=0.14 (from transformers)\n", | ||
" Downloading tokenizers-0.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)\n", | ||
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.8/3.8 MB\u001b[0m \u001b[31m67.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | ||
"\u001b[?25hCollecting safetensors>=0.3.1 (from transformers)\n", | ||
" Downloading safetensors-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", | ||
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m94.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | ||
"\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.1)\n", | ||
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n", | ||
"Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.1.0+cu118)\n", | ||
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (2023.6.0)\n", | ||
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (4.5.0)\n", | ||
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.12)\n", | ||
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.2.1)\n", | ||
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.2)\n", | ||
"Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2.1.0)\n", | ||
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n", | ||
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n", | ||
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n", | ||
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.7.22)\n", | ||
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.3)\n", | ||
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n", | ||
"Installing collected packages: safetensors, huggingface-hub, tokenizers, accelerate, transformers\n", | ||
"Successfully installed accelerate-0.24.1 huggingface-hub-0.19.4 safetensors-0.4.0 tokenizers-0.15.0 transformers-4.35.2\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"!pip install transformers accelerate --upgrade" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"from transformers import FuyuForCausalLM, FuyuProcessor, FuyuImageProcessor, AutoTokenizer\n", | ||
"from PIL import Image\n", | ||
"import torch\n", | ||
"\n", | ||
"model_name = \"ybelkada/fuyu-8b-sharded\"\n", | ||
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n", | ||
"\n", | ||
"image_processor = FuyuImageProcessor()\n", | ||
"processor = FuyuProcessor(\n", | ||
" image_processor = image_processor,\n", | ||
" tokenizer = tokenizer\n", | ||
")\n", | ||
"\n", | ||
"model = FuyuForCausalLM.from_pretrained(\n", | ||
" model_name,\n", | ||
" device_map = \"cuda\",\n", | ||
" torch_dtype = torch.float16,\n", | ||
")" | ||
], | ||
"metadata": { | ||
"id": "qBkMsXQ9pOiG" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"text_prompt = \"Generate a coco-style caption.\"\n", | ||
"image_path = \"bus.png\"\n", | ||
"image_pil = Image.open(image_path)\n", | ||
"text_prompt += \"\\n\"" | ||
], | ||
"metadata": { | ||
"id": "UL0eUmCwqZ05" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"model_inputs = processor(\n", | ||
" text=text_prompt,\n", | ||
" images=[image_pil],\n", | ||
" device=\"cuda\"\n", | ||
")\n", | ||
"for k, v in model_inputs.items():\n", | ||
" model_inputs[k] = v.to(\"cuda\")" | ||
], | ||
"metadata": { | ||
"id": "qDpK5pFcXvHW" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"generation_output = model.generate(**model_inputs, max_new_tokens=7)\n", | ||
"generation_text = processor.batch_decode(generation_output[:, -7:], skip_special_tokens=True)\n", | ||
"print(generation_text[0])" | ||
], | ||
"metadata": { | ||
"id": "42d0dwSLXvLe" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
} | ||
] | ||
} |