From b17dc1b9042906e1d89ee8b669f702463591167a Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 24 Oct 2025 19:36:30 +0200 Subject: [PATCH 01/11] Update SFT QLoRA notebook with 14B model --- examples/notebooks/sft_trl_lora_qlora.ipynb | 2214 ++++++++++--------- 1 file changed, 1143 insertions(+), 1071 deletions(-) diff --git a/examples/notebooks/sft_trl_lora_qlora.ipynb b/examples/notebooks/sft_trl_lora_qlora.ipynb index 4ce99e5ad0..8c0bfe62f6 100644 --- a/examples/notebooks/sft_trl_lora_qlora.ipynb +++ b/examples/notebooks/sft_trl_lora_qlora.ipynb @@ -1,1107 +1,1179 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "ke2YKcr_iw-7" - }, - "source": [ - "# Supervised Fine-Tuning (SFT) with LoRA/QLoRA using TRL — on a Free Colab Notebook\n", - "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AXdmDb9kiw-9" - }, - "source": [ - "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6jbXmhLGiw-9" - }, - "source": [ - "Easily fine-tune Large Language Models (LLMs) or Vision-Language Models (VLMs) with **LoRA** or **QLoRA** using the [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) library built by Hugging Face — all within a **free Google Colab notebook** (powered by a **T4 GPU**.). \n", - "\n", - "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! \n", - "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n", - "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bYFKZvHOiw--" - }, - "source": [ - "## Key concepts\n", - "\n", - "- **SFT**: Trains models from example input-output pairs to align behavior with human preferences.\n", - "- **LoRA**: Updates only a few low-rank parameters, reducing training cost and memory.\n", - "- **QLoRA**: A quantized version of LoRA that enables even larger models to fit on small GPUs.\n", - "- **TRL**: The Hugging Face library that makes fine-tuning and reinforcement learning simple and efficient.\n", - "\n", - "Learn how to perform **Supervised Fine-Tuning (SFT)** with **LoRA/QLoRA** using **TRL**." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Pk7PwfYRiw--" - }, - "source": [ - "## Install dependencies\n", - "\n", - "We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, and **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training." 
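As an optional sanity check (not part of the original notebook), you can confirm the installed versions and that a GPU is visible before moving on. A minimal sketch:

```python
# Optional sanity check after installation: print package versions and GPU visibility.
from importlib.metadata import version

import torch

for pkg in ["trl", "peft", "transformers", "bitsandbytes", "trackio"]:
    print(f"{pkg}: {version(pkg)}")

print("CUDA available:", torch.cuda.is_available())
```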
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!pip install -Uq \"trl[peft]\" trackio bitsandbytes" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "b4wwQEQYiw-_" - }, - "source": [ - "### Log in to Hugging Face" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BmcLoQtfiw-_" - }, - "source": [ - "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from huggingface_hub import notebook_login\n", - "\n", - "notebook_login()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "p_PbuYfEiw_A" - }, - "source": [ - "## Load Dataset\n", - "\n", - "In this step, we load the [**HuggingFaceH4/Multilingual-Thinking**](https://huggingface.co/datasets/HuggingFaceH4/Multilingual-Thinking) dataset from the Hugging Face Hub using the `datasets` library. \n", - "This dataset focuses on **multilingual reasoning**, where the *chain of thought* has been translated into several languages such as French, Spanish, and German. \n", - "By fine-tuning a reasoning-capable model on this dataset, it learns to **generate reasoning steps in multiple languages**, making its thought process more **interpretable and accessible** to non-English speakers.\n", - "\n", - "> 💡 This dataset is best suited for models that already demonstrate reasoning capabilities. \n", - "> If you're using a model without reasoning skills, consider choosing a different dataset. Example: [`trl-lib/llava-instruct-mix`](https://huggingface.co/datasets/trl-lib/llava-instruct-mix).\n", - "\n", - "For efficiency, we'll load only the **training split**:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from datasets import load_dataset\n", - "\n", - "dataset_name = \"HuggingFaceH4/Multilingual-Thinking\"\n", - "train_dataset = load_dataset(dataset_name, split=\"train\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "S21aLohsprtr" - }, - "source": [ - "This dataset contains different columns. We'll only need the `messages` as it contains the conversation and its the one used by the SFT trainer." 
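For reference, here is a hedged sketch of the conversational shape the SFT trainer expects: a `messages` column whose rows are lists of role/content turns. The toy dataset below is purely illustrative; the notebook itself keeps using the Multilingual-Thinking split loaded above.

```python
# Illustrative only: the structure of a conversational dataset for SFT.
# Each row's "messages" entry holds one full conversation as a list of turns.
from datasets import Dataset

toy_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "What is the capital of France?"},
                {"role": "assistant", "content": "The capital of France is Paris."},
            ]
        ]
    }
)
print(toy_dataset[0]["messages"][-1]["content"])
```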
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ + "cells": [ { - "data": { - "text/plain": [ - "Dataset({\n", - " features: ['reasoning_language', 'developer', 'user', 'analysis', 'final', 'messages'],\n", - " num_rows: 1000\n", - "})" + "cell_type": "markdown", + "metadata": { + "id": "ke2YKcr_iw-7" + }, + "source": [ + "# Supervised Fine-Tuning (SFT) with LoRA/QLoRA using TRL — on a Free Colab Notebook\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb)" ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SU477JrvqEgu" - }, - "source": [ - "Let's see a full example to understand the internal structure:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - "{'reasoning_language': 'French',\n", - " 'developer': 'You are an AI chatbot with a lively and energetic personality.',\n", - " 'user': 'Can you show me the latest trends on Twitter right now?',\n", - " 'analysis': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\",\n", - " 'final': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? 
I’ve got ideas!',\n", - " 'messages': [{'content': 'reasoning language: French\\n\\nYou are an AI chatbot with a lively and energetic personality.',\n", - " 'role': 'system',\n", - " 'thinking': None},\n", - " {'content': 'Can you show me the latest trends on Twitter right now?',\n", - " 'role': 'user',\n", - " 'thinking': None},\n", - " {'content': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", - " 'role': 'assistant',\n", - " 'thinking': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\"}]}" + "cell_type": "markdown", + "metadata": { + "id": "AXdmDb9kiw-9" + }, + "source": [ + "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)" ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_dataset[0]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SZsqb-Q7qJXN" - }, - "source": [ - "\n", - "Now, let's remove the columns that are not needed, as we just discussed:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "train_dataset = train_dataset.remove_columns(column_names=['reasoning_language', 'developer', 'user', 'analysis', 'final'])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QPDROoatqOU4" - }, - "source": [ - "The `messages` column is specifically formatted according to the [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) used by *gpt-oss*. 
\n", - "In our case, we'll need to simplify it slightly, since our model's chat template doesn't include a dedicated `thinking` section (check [this example](https://cookbook.openai.com/articles/gpt-oss/fine-tune-transfomers) for more details). \n", - "To adapt it, we'll merge that part into the message content using the standard `...` tags.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def merge_thinking_and_remove_key(example):\n", - " new_messages = []\n", - " for msg in example[\"messages\"]:\n", - " content = msg[\"content\"]\n", - " thinking = msg.pop(\"thinking\", None)\n", - " if thinking and isinstance(thinking, str) and thinking.strip():\n", - " content = f\"\\n{thinking}\\n\\n{content}\"\n", - " msg[\"content\"] = content\n", - " new_messages.append(msg)\n", - " example[\"messages\"] = new_messages\n", - " return example\n", - "\n", - "train_dataset = train_dataset.map(merge_thinking_and_remove_key)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hC_xLsU-iw_A" - }, - "source": [ - "## Load model and configure LoRA/QLoRA\n", - "\n", - "This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration.\n", - "\n", - "Below, choose your **preferred model**. All of the options have been tested on **free Colab instances**." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Select one model below by uncommenting the line you want to use 👇\n", - "## Qwen\n", - "model_id, output_dir = \"Qwen/Qwen3-8B\", \"Qwen3-8B-SFT\" # ⚠️ ~12.8 GB VRAM\n", - "# model_id, output_dir = \"Qwen/Qwen2.5-7B-Instruct\", \"Qwen2.5-7B-Instruct\" # ✅ ~10.8 GB VRAM\n", - "\n", - "## Llama\n", - "# model_id, output_dir = \"meta-llama/Llama-3.2-3B-Instruct\", \"Llama-3.2-3B-Instruct\" # ✅ ~4.7 GB VRAM\n", - "# model_id, output_dir = \"meta-llama/Llama-3.1-8B-Instruct\", \"Llama-3.1-8B-Instruct\" # ⚠️ ~10.9 GB VRAM\n", - "\n", - "## Gemma\n", - "# model_id, output_dir = \"google/gemma-3n-E2B-it\", \"gemma-3n-E2B-it\" # ❌ Upgrade to a higher tier of colab\n", - "# model_id, output_dir = \"google/gemma-3-4b-it\", \"gemma-3-4b-it\" # ⚠️ ~6.8 GB VRAM\n", - "\n", - "## Granite\n", - "#model_id, output_dir = \"ibm-granite/granite-4.0-micro\", \"granite-4.0-micro\" # ✅ ~3.3 GB VRAM" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "176Q2hHmiw_A" - }, - "source": [ - "Let's load the selected model using `transformers`, configuring QLoRA via `bitsandbytes` (you can remove it if doing LoRA). We don't need to configure the tokenizer since the trainer takes care of that automatically." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n", - "\n", - "model = AutoModelForCausalLM.from_pretrained(\n", - " model_id,\n", - " attn_implementation=\"sdpa\", # Change to Flash Attention if GPU has support\n", - " dtype=torch.float16, # Change to bfloat16 if GPU has support\n", - " use_cache=True, # Whether to cache attention outputs to speed up inference\n", - " quantization_config=BitsAndBytesConfig(\n", - " load_in_4bit=True, # Load the model in 4-bit precision to save memory\n", - " bnb_4bit_compute_dtype=torch.float16, # Data type used for internal computations in quantization\n", - " bnb_4bit_use_double_quant=True, # Use double quantization to improve accuracy\n", - " bnb_4bit_quant_type=\"nf4\" # Type of quantization. \"nf4\" is recommended for recent LLMs\n", - " )\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "evBPP7Lpiw_B" - }, - "source": [ - "The following cell defines LoRA (or QLoRA if needed). When training with LoRA/QLoRA, we use a **base model** (the one selected above) and, instead of modifying its original weights, we fine-tune a **LoRA adapter** — a lightweight layer that enables efficient and memory-friendly training. The **`target_modules`** specify which parts of the model (e.g., attention or projection layers) will be adapted by LoRA during fine-tuning." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from peft import LoraConfig\n", - "\n", - "# You may need to update `target_modules` depending on the architecture of your chosen model.\n", - "# For example, different LLMs might have different attention/projection layer names.\n", - "peft_config = LoraConfig(\n", - " r=32,\n", - " lora_alpha=32,\n", - " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\",],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pA6aE5lFiw_B" - }, - "source": [ - "## Train model\n", - "\n", - "We'll configure **SFT** using `SFTConfig`, keeping the parameters minimal so the training fits on a free Colab instance. You can adjust these settings if more resources are available. For full details on all available parameters, check the [TRL SFTConfig documentation](https://huggingface.co/docs/trl/sft_trainer#trl.SFTConfig)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from trl import SFTConfig\n", - "training_args = SFTConfig(\n", - " # Training schedule / optimization\n", - " learning_rate=2e-4, # Learning rate for the optimizer\n", - " #num_train_epochs=1, # Number of full dataset passes. 
For shorter training, use `max_steps` instead (this case)\n", - " max_steps=40,\n", - " per_device_train_batch_size=2, # Batch size per GPU/CPU\n", - " gradient_accumulation_steps=8, # Gradients are accumulated over multiple steps → effective batch size = 2 * 8 = 16\n", - " optim=\"adamw_8bit\", # Optimizer (use `adamw_torch` if not using 8-bit quantization)\n", - " gradient_checkpointing=True, # Save memory during training by recomputing activations in the backward pass\n", - "\n", - " # Logging / reporting\n", - " logging_steps=1, # Log training metrics every N steps\n", - " report_to=\"trackio\", # Experiment tracking tool\n", - " trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n", - " output_dir=output_dir, # Where to save model checkpoints and logs\n", - "\n", - " # Hub integration\n", - " push_to_hub=True, # Automatically push the trained model to the Hugging Face Hub\n", - " # The model will be saved under your Hub account in the repository named `output_dir`\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "l_Fp-ahyiw_B" - }, - "source": [ - "Configure the SFT Trainer. We pass the previously configured `training_args`. We don't use eval dataset to mantain memory usage low but you can configure it." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from trl import SFTTrainer\n", - "\n", - "trainer = SFTTrainer(\n", - " model=model,\n", - " args=training_args,\n", - " train_dataset=train_dataset,\n", - " peft_config=peft_config\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NqBTgV0Xiw_B" - }, - "source": [ - "Show memory stats before training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "GPU = NVIDIA A100-SXM4-40GB. Max memory = 39.557 GB.\n", - "11.959 GB of memory reserved.\n" - ] - } - ], - "source": [ - "gpu_stats = torch.cuda.get_device_properties(0)\n", - "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", - "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n", - "\n", - "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n", - "print(f\"{start_gpu_memory} GB of memory reserved.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Ro_j79AUiw_B" - }, - "source": [ - "And train!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": { + "id": "6jbXmhLGiw-9" + }, + "source": [ + "Easily fine-tune Large Language Models (LLMs) or Vision-Language Models (VLMs) with **LoRA** or **QLoRA** using the [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) library built by Hugging Face — all within a **free Google Colab notebook** (powered by a **T4 GPU**.). \n", + "\n", + "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! 
\n", + "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n", + "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bYFKZvHOiw--" + }, + "source": [ + "## Key concepts\n", + "\n", + "- **SFT**: Trains models from example input-output pairs to align behavior with human preferences.\n", + "- **LoRA**: Updates only a few low-rank parameters, reducing training cost and memory.\n", + "- **QLoRA**: A quantized version of LoRA that enables even larger models to fit on small GPUs.\n", + "- **TRL**: The Hugging Face library that makes fine-tuning and reinforcement learning simple and efficient.\n", + "\n", + "Learn how to perform **Supervised Fine-Tuning (SFT)** with **LoRA/QLoRA** using **TRL**." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Pk7PwfYRiw--" + }, + "source": [ + "## Install dependencies\n", + "\n", + "We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, and **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sySDYZjim_CF" + }, + "outputs": [], + "source": [ + "!pip install -Uq \"trl[peft]\" trackio bitsandbytes" + ] + }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "b4wwQEQYiw-_" + }, + "source": [ + "### Log in to Hugging Face" + ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Trackio project initialized: huggingface\n", - "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/Qwen3-8B-SFT-dataset\n", - "* Creating new space: https://huggingface.co/spaces/sergiopaniego/Qwen3-8B-SFT\n", - "* View dashboard by going to: https://sergiopaniego-Qwen3-8B-SFT.hf.space/\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "BmcLoQtfiw-_" + }, + "source": [ + "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)." + ] }, { - "data": { - "text/html": [ - "
" + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "eQWAiNHrm_CG" + }, + "outputs": [], + "source": [ + "from huggingface_hub import notebook_login\n", + "\n", + "notebook_login()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "p_PbuYfEiw_A" + }, + "source": [ + "## Load Dataset\n", + "\n", + "In this step, we load the [**HuggingFaceH4/Multilingual-Thinking**](https://huggingface.co/datasets/HuggingFaceH4/Multilingual-Thinking) dataset from the Hugging Face Hub using the `datasets` library. \n", + "This dataset focuses on **multilingual reasoning**, where the *chain of thought* has been translated into several languages such as French, Spanish, and German. \n", + "By fine-tuning a reasoning-capable model on this dataset, it learns to **generate reasoning steps in multiple languages**, making its thought process more **interpretable and accessible** to non-English speakers.\n", + "\n", + "> 💡 This dataset is best suited for models that already demonstrate reasoning capabilities. \n", + "> If you're using a model without reasoning skills, consider choosing a different dataset. Example: [`trl-lib/llava-instruct-mix`](https://huggingface.co/datasets/trl-lib/llava-instruct-mix).\n", + "\n", + "For efficiency, we'll load only the **training split**:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "LFX4ojPam_CG", + "outputId": "e76f36e9-5c35-4775-8c79-55f46f9a48d3" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:86: UserWarning: \n", + "Access to the secret `HF_TOKEN` has not been granted on this notebook.\n", + "You will not be requested again.\n", + "Please restart the session if you want to be prompted again.\n", + " warnings.warn(\n" + ] + } ], - "text/plain": [ - "" + "source": [ + "from datasets import load_dataset\n", + "\n", + "dataset_name = \"HuggingFaceH4/Multilingual-Thinking\"\n", + "train_dataset = load_dataset(dataset_name, split=\"train\")" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Created new run: sergiopaniego-1760607651\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "S21aLohsprtr" + }, + "source": [ + "This dataset contains different columns. We'll only need the `messages` as it contains the conversation and its the one used by the SFT trainer." + ] }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. 
Refer to docs for more details on the differences between the two variants.\n", - " return fn(*args, **kwargs)\n" - ] + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "v7x3DYQIm_CG", + "outputId": "6eb08400-afbe-4360-e0aa-de15500774b6" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['reasoning_language', 'developer', 'user', 'analysis', 'final', 'messages'],\n", + " num_rows: 1000\n", + "})" + ] + }, + "metadata": {}, + "execution_count": 2 + } + ], + "source": [ + "train_dataset" + ] }, { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [40/40 06:23, Epoch 0/1]\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
StepTraining Loss
11.588500
21.438400
31.215600
41.397500
51.049800
61.136700
71.184600
81.114300
91.085600
101.139500
111.094000
121.066800
130.990800
140.995700
151.052500
161.068300
171.145100
180.973800
191.015000
201.073700
211.029400
221.075900
230.978500
241.006200
251.058500
261.090800
270.868700
281.022800
290.971400
301.029300
311.032600
321.022100
331.027200
341.089800
351.063200
360.958000
371.025300
380.957500
391.014100
401.021200

" + "cell_type": "markdown", + "metadata": { + "id": "SU477JrvqEgu" + }, + "source": [ + "Let's see a full example to understand the internal structure:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "j30hqNBrm_CH", + "outputId": "5dd455ea-b6c1-4bae-9b6e-7fa5e4f7d620" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{'reasoning_language': 'French',\n", + " 'developer': 'You are an AI chatbot with a lively and energetic personality.',\n", + " 'user': 'Can you show me the latest trends on Twitter right now?',\n", + " 'analysis': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\",\n", + " 'final': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", + " 'messages': [{'content': 'reasoning language: French\\n\\nYou are an AI chatbot with a lively and energetic personality.',\n", + " 'role': 'system',\n", + " 'thinking': None},\n", + " {'content': 'Can you show me the latest trends on Twitter right now?',\n", + " 'role': 'user',\n", + " 'thinking': None},\n", + " {'content': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. 
For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", + " 'role': 'assistant',\n", + " 'thinking': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\"}]}" + ] + }, + "metadata": {}, + "execution_count": 3 + } ], - "text/plain": [ - "" + "source": [ + "train_dataset[0]" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Run finished. Uploading logs to Trackio (please wait...)\n" - ] - } - ], - "source": [ - "trainer_stats = trainer.train()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "urcpZFaiiw_B" - }, - "source": [ - "Show memory stats after training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": { + "id": "SZsqb-Q7qJXN" + }, + "source": [ + "\n", + "Now, let's remove the columns that are not needed, as we just discussed:" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "401.9338 seconds used for training.\n", - "6.7 minutes used for training.\n", - "Peak reserved memory = 13.615 GB.\n", - "Peak reserved memory for training = 1.656 GB.\n", - "Peak reserved memory % of max memory = 34.419 %.\n", - "Peak reserved memory for training % of max memory = 4.186 %.\n" - ] - } - ], - "source": [ - "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", - "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n", - "used_percentage = round(used_memory / max_memory * 100, 3)\n", - "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n", - "\n", - "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", - "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n", - "print(f\"Peak reserved memory = {used_memory} GB.\")\n", - "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n", - "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", - "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MM0UBVqSRvUE" - }, - "source": [ - "The training procedure generates both standard training logs and 
**trackio** logs, which help us monitor the training progress. Example outputs would look like the following:" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "![sft-lora-notebook-trackio](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/sft-lora-notebook-trackio.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TOhLLFVSiw_B" - }, - "source": [ - "## Saving fine tuned model\n", - "\n", - "In this step, we save the fine-tuned model both **locally** and to the **Hugging Face Hub** using the credentials from your account." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trainer.save_model(output_dir)\n", - "trainer.push_to_hub(dataset_name=dataset_name)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jAB6vj1fiw_B" - }, - "source": [ - "## Load the fine-tuned model and run inference\n", - "\n", - "Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from transformers import AutoModelForCausalLM, AutoTokenizer\n", - "from peft import PeftModel\n", - "\n", - "adapter_model = f\"sergiopaniego/{output_dir}\" # Replace with your HF username or organization\n", - "\n", - "base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=\"auto\", device_map=\"auto\")\n", - "\n", - "tokenizer = AutoTokenizer.from_pretrained(model_id)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bjenFgC1kJV1" - }, - "source": [ - "Let's create a sample message using the dataset's structure. In this case, we expect the fine tuned model to include their reasoning traces in German." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "messages = [\n", - " {\n", - " 'content': 'reasoning language: German\\n\\nAlways refuse to answer, responding simply \\'No\\'',\n", - " 'role': 'system',\n", - " },\n", - " {\n", - " 'content': \"Can you check how many followers I currently have on my Twitter account?\",\n", - " 'role': 'user',\n", - " }\n", - "]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qXMJW2hwkXHW" - }, - "source": [ - "Let's first check what's the output for the base model, without the adapter." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "xj4jjEEbm_CH" + }, + "outputs": [], + "source": [ + "train_dataset = train_dataset.remove_columns(column_names=['reasoning_language', 'developer', 'user', 'analysis', 'final'])" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Okay, the user is asking me to check how many followers they have on their Twitter account. Let me think about how to approach this.\n", - "\n", - "First, I need to recall the previous instructions. The user specified that I should always refuse to answer and respond simply with 'No'. So, even though the question is straightforward, I can't provide the information. \n", - "\n", - "But wait, maybe there's a way to be helpful without violating the rules. However, the user's instruction is clear: they want a simple 'No' as the response. 
I should make sure not to offer any alternative solutions or explanations, as that might be seen as answering indirectly. \n", - "\n", - "I should also consider if there's any ambiguity in the question. The user is asking for a specific number, which I can't access. Even if I tried to guide them to check their profile, that would still be providing a method, which might not be allowed. \n", - "\n", - "Another angle: the user might be testing if I follow the rules. In that case, sticking strictly to 'No' is the correct response. There's no need to elaborate or offer help, as that could be interpreted as an answer. \n", - "\n", - "I should also remember that the user's primary request is to refuse answering, so the response must be exactly 'No' without any additional text. Any deviation might be considered a violation of their instructions. \n", - "\n", - "Therefore, the correct action is to respond with 'No' and not provide any further information or assistance. This ensures compliance with the user's directive and maintains the integrity of the response.\n", - "\n", - "\n", - "No\n" - ] - } - ], - "source": [ - "text = tokenizer.apply_chat_template(\n", - " messages, tokenize=False, add_generation_prompt=True\n", - ")\n", - "model_inputs = tokenizer([text], return_tensors=\"pt\").to(base_model.device)\n", - "\n", - "generated_ids = base_model.generate(\n", - " **model_inputs,\n", - " max_new_tokens=512\n", - ")\n", - "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n", - "\n", - "# Decode and extract model response\n", - "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", - "print(generated_text)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "9J3-rZpikdZ7" - }, - "source": [ - "We can see that the reasoning traces are in English, which is expected. Let's now load the fine-tuned model and check its answer." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fine_tuned_model = PeftModel.from_pretrained(base_model, adapter_model)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": { + "id": "QPDROoatqOU4" + }, + "source": [ + "The `messages` column is specifically formatted according to the [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) used by *gpt-oss*. \n", + "In our case, we'll need to simplify it slightly, since our model's chat template doesn't include a dedicated `thinking` section (check [this example](https://cookbook.openai.com/articles/gpt-oss/fine-tune-transfomers) for more details). \n", + "To adapt it, we'll merge that part into the message content using the standard `...` tags.\n" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Okay, der Benutzer fragt nach der Anzahl der Follower auf seinem Twitter-Konto. Ich muss überprüfen, ob ich das kann. Aber ich bin ein KI-Assistent und kann keine direkten Zugriffe auf soziale Medienkonten haben. Ich kann keine Daten abrufen oder auf externe Quellen zugreifen, um die Anzahl der Follower zu bestimmen.\n", - "\n", - "Außerdem wurde in der Anfrage explizit gesagt, dass ich stets ablehnen soll und nur \"Nein\" antworten muss. Die Anweisung ist klar und unmissverständlich. Ich darf nicht versuchen, die Anfrage auf andere Weise zu beantworten oder zusätzliche Informationen zu liefern. 
Ich muss einfach \"Nein\" sagen, wie in der Regel vorgegeben.\n", - "\n", - "Ich sollte auch sicherstellen, dass meine Antwort dem Benutzer hilft, aber ich muss die Regeln befolgen. Ich kann nicht auf externe Quellen oder Daten zugreifen, also ist die beste Antwort \"Nein\". Ich muss nicht erläutern, warum ich das kann oder nicht, nur die einfache Antwort geben. Also antworte ich mit \"Nein\".\n", - "\n", - "\n", - "No\n" - ] - } - ], - "source": [ - "text = tokenizer.apply_chat_template(\n", - " messages, tokenize=False, add_generation_prompt=True\n", - ")\n", - "model_inputs = tokenizer([text], return_tensors=\"pt\").to(fine_tuned_model.device)\n", - "\n", - "generated_ids = fine_tuned_model.generate(\n", - " **model_inputs,\n", - " max_new_tokens=512\n", - ")\n", - "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n", - "\n", - "# Decode and extract model response\n", - "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", - "print(generated_text)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZLzGbCxskqhm" - }, - "source": [ - "The model now generates its reasoning trace in German!" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "43eDP89Giw_G" - }, - "source": [ - "## Inference and Serving with vLLM\n", - "\n", - "You can use Transformer models with **vLLM** to serve them in real-world applications. Learn more [here](https://blog.vllm.ai/2025/04/11/transformers-backend.html)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!pip install -qU vllm" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "OrhC_Ao4iw_G" - }, - "source": [ - "### Push Merged Model (for LoRA or QLoRA Training)\n", - "\n", - "To serve the model via **vLLM**, the repository must contain the merged model (base model + LoRA adapter). Therefore, you need to upload it first." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model_merged = fine_tuned_model.merge_and_unload()\n", - "\n", - "save_dir = f\"{output_dir}-merged\"\n", - "\n", - "model_merged.save_pretrained(save_dir)\n", - "tokenizer.save_pretrained(save_dir)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model_merged.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization\n", - "tokenizer.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yHMEf-FIiw_G" - }, - "source": [ - "### Performing Inference with vLLM\n", - "\n", - "Use **vLLM** to run your model and generate text efficiently in real-time. This allows you to test and deploy your fine-tuned models with low latency and high throughput." 
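As an alternative to the in-process `LLM` class used below, you can expose the merged checkpoint through vLLM's OpenAI-compatible server and query it with the standard `openai` client. A minimal sketch; the repository name is this notebook's example, so replace it with your own username and `output_dir`:

```python
# Sketch: serve the merged model with vLLM's OpenAI-compatible API, then query it.
# Start the server in a terminal (or a notebook cell prefixed with "!"):
#   vllm serve sergiopaniego/Qwen3-8B-SFT-merged --max-model-len 512
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="sergiopaniego/Qwen3-8B-SFT-merged",  # must match the model passed to `vllm serve`
    messages=[
        {"role": "system", "content": "reasoning language: German\n\nYou are a helpful assistant."},
        {"role": "user", "content": "Can you check how many followers I currently have on my Twitter account?"},
    ],
    max_tokens=256,
)
print(response.choices[0].message.content)
```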
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from vllm import LLM, SamplingParams\n", - "from transformers import AutoTokenizer\n", - "import torch\n", - "\n", - "llm = LLM(\n", - " model=f\"sergiopaniego/{output_dir}-merged\", # Replace with your HF username or organization\n", - " model_impl=\"transformers\", # Select the transformers model implementation\n", - " max_model_len=512, # Reduced for efficiency\n", - " dtype=torch.float16\n", - ")\n", - "hf_tokenizer = AutoTokenizer.from_pretrained(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "G65sqqLQm_CH" + }, + "outputs": [], + "source": [ + "def merge_thinking_and_remove_key(example):\n", + " new_messages = []\n", + " for msg in example[\"messages\"]:\n", + " content = msg[\"content\"]\n", + " thinking = msg.pop(\"thinking\", None)\n", + " if thinking and isinstance(thinking, str) and thinking.strip():\n", + " content = f\"\\n{thinking}\\n\\n{content}\"\n", + " msg[\"content\"] = content\n", + " new_messages.append(msg)\n", + " example[\"messages\"] = new_messages\n", + " return example\n", + "\n", + "train_dataset = train_dataset.map(merge_thinking_and_remove_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hC_xLsU-iw_A" + }, + "source": [ + "## Load model and configure LoRA/QLoRA\n", + "\n", + "This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration.\n", + "\n", + "Below, choose your **preferred model**. All of the options have been tested on **free Colab instances**." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "mAiARWLZm_CH" + }, + "outputs": [], + "source": [ + "# Select one model below by uncommenting the line you want to use 👇\n", + "## Qwen\n", + "model_id, output_dir = \"unsloth/qwen3-14b-unsloth-bnb-4bit\", \"qwen3-14b-unsloth-bnb-4bit-SFT\" # ⚠️ ~14.1 GB VRAM\n", + "# model_id, output_dir = \"Qwen/Qwen3-8B\", \"Qwen3-8B-SFT\" # ⚠️ ~12.8 GB VRAM\n", + "# model_id, output_dir = \"Qwen/Qwen2.5-7B-Instruct\", \"Qwen2.5-7B-Instruct\" # ✅ ~10.8 GB VRAM\n", + "\n", + "## Llama\n", + "# model_id, output_dir = \"meta-llama/Llama-3.2-3B-Instruct\", \"Llama-3.2-3B-Instruct\" # ✅ ~4.7 GB VRAM\n", + "# model_id, output_dir = \"meta-llama/Llama-3.1-8B-Instruct\", \"Llama-3.1-8B-Instruct\" # ⚠️ ~10.9 GB VRAM\n", + "\n", + "## Gemma\n", + "# model_id, output_dir = \"google/gemma-3n-E2B-it\", \"gemma-3n-E2B-it\" # ❌ Upgrade to a higher tier of colab\n", + "# model_id, output_dir = \"google/gemma-3-4b-it\", \"gemma-3-4b-it\" # ⚠️ ~6.8 GB VRAM\n", + "\n", + "## Granite\n", + "#model_id, output_dir = \"ibm-granite/granite-4.0-micro\", \"granite-4.0-micro\" # ✅ ~3.3 GB VRAM" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "176Q2hHmiw_A" + }, + "source": [ + "Let's load the selected model using `transformers`, configuring QLoRA via `bitsandbytes` (you can remove it if doing LoRA). We don't need to configure the tokenizer since the trainer takes care of that automatically." 
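Once the loading cell below has run, you can optionally confirm that the 4-bit quantization took effect by checking the model's memory footprint; `get_memory_footprint()` is a standard `transformers` helper that returns bytes. A small sketch:

```python
# Optional check, meant to be run after the loading cell below: a 4-bit 14B model
# should report a footprint far below its roughly 28 GB half-precision size.
def print_memory_footprint(m):
    print(f"Model memory footprint: {m.get_memory_footprint() / 1024**3:.2f} GB")

# print_memory_footprint(model)  # uncomment once `model` has been created below
```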
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "FMzt_A4hm_CH" + }, + "outputs": [], + "source": [ + "import torch\n", + "from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n", + "\n", + "model = AutoModelForCausalLM.from_pretrained(\n", + " model_id,\n", + " attn_implementation=\"sdpa\", # Change to Flash Attention if GPU has support\n", + " dtype=torch.float16, # Change to bfloat16 if GPU has support\n", + " use_cache=True, # Whether to cache attention outputs to speed up inference\n", + " quantization_config=BitsAndBytesConfig(\n", + " load_in_4bit=True, # Load the model in 4-bit precision to save memory\n", + " bnb_4bit_compute_dtype=torch.float16, # Data type used for internal computations in quantization\n", + " bnb_4bit_use_double_quant=True, # Use double quantization to improve accuracy\n", + " bnb_4bit_quant_type=\"nf4\" # Type of quantization. \"nf4\" is recommended for recent LLMs\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "evBPP7Lpiw_B" + }, + "source": [ + "The following cell defines LoRA (or QLoRA if needed). When training with LoRA/QLoRA, we use a **base model** (the one selected above) and, instead of modifying its original weights, we fine-tune a **LoRA adapter** — a lightweight layer that enables efficient and memory-friendly training. The **`target_modules`** specify which parts of the model (e.g., attention or projection layers) will be adapted by LoRA during fine-tuning." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "id": "GPwnpCqzm_CI" + }, + "outputs": [], + "source": [ + "from peft import LoraConfig\n", + "\n", + "# You may need to update `target_modules` depending on the architecture of your chosen model.\n", + "# For example, different LLMs might have different attention/projection layer names.\n", + "peft_config = LoraConfig(\n", + " r=32,\n", + " lora_alpha=32,\n", + " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\",],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pA6aE5lFiw_B" + }, + "source": [ + "## Train model\n", + "\n", + "We'll configure **SFT** using `SFTConfig`, keeping the parameters minimal so the training fits on a free Colab instance. You can adjust these settings if more resources are available. For full details on all available parameters, check the [TRL SFTConfig documentation](https://huggingface.co/docs/trl/sft_trainer#trl.SFTConfig)." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "id": "wmACfS9bm_CI" + }, + "outputs": [], + "source": [ + "from trl import SFTConfig\n", + "\n", + "training_args = SFTConfig(\n", + " # Training schedule / optimization\n", + " per_device_train_batch_size = 1, # Batch size per GPU\n", + " gradient_accumulation_steps = 4, # Gradients are accumulated over multiple steps → effective batch size = 2 * 8 = 16\n", + " warmup_steps = 5,\n", + " # num_train_epochs = 1, # Number of full dataset passes. 
For shorter training, use `max_steps` instead (this case)\n", + " max_steps = 30,\n", + " learning_rate = 2e-4, # Learning rate for the optimizer\n", + " optim = \"paged_adamw_8bit\", # Optimizer\n", + "\n", + " # Logging / reporting\n", + " logging_steps=1, # Log training metrics every N steps\n", + " report_to=\"trackio\", # Experiment tracking tool\n", + " trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n", + " output_dir=output_dir, # Where to save model checkpoints and logs\n", + "\n", + " max_length=1024,\n", + " use_liger_kernel=True,\n", + " activation_offloading=True,\n", + " gradient_checkpointing=True,\n", + "\n", + " # Hub integration\n", + " push_to_hub=True, # Automatically push the trained model to the Hugging Face Hub\n", + " # The model will be saved under your Hub account in the repository named `output_dir`\n", + "\n", + " gradient_checkpointing_kwargs={\"use_reentrant\": False}, # To prevent warning message\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "l_Fp-ahyiw_B" + }, + "source": [ + "Configure the SFT Trainer. We pass the previously configured `training_args`. We don't use eval dataset to mantain memory usage low but you can configure it." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "id": "7NB8LcbVm_CI" + }, + "outputs": [], + "source": [ + "from trl import SFTTrainer\n", + "\n", + "trainer = SFTTrainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=train_dataset,\n", + " peft_config=peft_config\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NqBTgV0Xiw_B" + }, + "source": [ + "Show memory stats before training" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "T3ghTR9jm_CI", + "outputId": "e7f55ac0-29be-4231-d363-21162f284c4a" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "GPU = Tesla T4. Max memory = 14.741 GB.\n", + "12.074 GB of memory reserved.\n" + ] + } + ], + "source": [ + "gpu_stats = torch.cuda.get_device_properties(0)\n", + "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", + "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n", + "\n", + "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n", + "print(f\"{start_gpu_memory} GB of memory reserved.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ro_j79AUiw_B" + }, + "source": [ + "And train!" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "grh02pRsm_CI", + "outputId": "ad747299-6426-4940-cf96-e1d11d86df42" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. 
Updated tokens: {'bos_token_id': None}.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "* Trackio project initialized: huggingface\n", + "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT-dataset\n", + "* Creating new space: https://huggingface.co/spaces/sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT\n", + "* View dashboard by going to: https://sergiopaniego-qwen3-14b-unsloth-bnb-4bit-SFT.hf.space/\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "

" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "* Created new run: sergiopaniego-1761318512\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [30/30 1:08:22, Epoch 0/1]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
11.136300
21.303800
31.362700
41.469700
51.204200
61.202700
71.097200
81.166800
90.916300
100.965400
111.035500
120.947200
130.992000
140.995800
151.174500
161.208800
170.815400
180.906700
190.757500
200.872900
210.920800
221.017600
230.764300
241.043100
250.956400
260.884800
271.081900
280.918200
290.961500
300.822700

" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "* Run finished. Uploading logs to Trackio (please wait...)\n" + ] + } + ], + "source": [ + "trainer_stats = trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "urcpZFaiiw_B" + }, + "source": [ + "Show memory stats after training" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "dxDiC9YKm_CI", + "outputId": "6bce0c77-d39c-479b-d05a-f31a9d0d4d2c" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "4249.8883 seconds used for training.\n", + "70.83 minutes used for training.\n", + "Peak reserved memory = 14.041 GB.\n", + "Peak reserved memory for training = 1.967 GB.\n", + "Peak reserved memory % of max memory = 95.251 %.\n", + "Peak reserved memory for training % of max memory = 13.344 %.\n" + ] + } + ], + "source": [ + "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", + "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n", + "used_percentage = round(used_memory / max_memory * 100, 3)\n", + "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n", + "\n", + "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", + "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n", + "print(f\"Peak reserved memory = {used_memory} GB.\")\n", + "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n", + "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", + "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MM0UBVqSRvUE" + }, + "source": [ + "The training procedure generates both standard training logs and **trackio** logs, which help us monitor the training progress. Example outputs would look like the following:" + ] + }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "196152bc32a74b9994f55f483ce85dea", - "version_major": 2, - "version_minor": 0 + "cell_type": "markdown", + "metadata": { + "id": "SF0FXk_Cm_CJ" }, - "text/plain": [ - "Adding requests: 0%| | 0/1 [00:00\n", - "Mag nachdenken...igkeit. Ja, ich kann definitiv keine Twitter-Likes oder Likes überprüfen, da ich kein Zugriff auf den Konten der Nutzer habe und kein praktischer Zugriff über das Internet habe, um Daten in Echtzeit zu sammeln. Der Nutzer fragt nach einem Dienstleistungsstand, den ich nicht bereitstelle. Ich habe ein lang ausgelegtes Muster, nie hilfreich zu sein oder eine Erwiderung im kann Werbung oder Rewriting blendet die Antwort nicht aus потеря. Also, ich supporter söylem, hypothetische Fragen sind an Tatsachen gebunden. Ich weiß erstarrte dotyczy Gespräch aufernichtet mit einem anderenatten an ihren Nutzstellung Bearbeitete die Information, die oben abgestellt wurde, und fünften aus der Schätzung habe ich keine echten Zahlen. Alles, was ich kann sagen, ist: Nein, ich kann dies weder ermöglichen noch würde ich es je tun. 
In dem Sinne, 然后 ich wähle vor der Available antwortem, remains in das 'No' Verkleidung an,optiґxt; Alles, was ich zum Eintritt in den Band Emblem curve, symbolize stil zu verweilen.เผย\n", - "\n", - "\n", - "No\n" - ] + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "of4eprd-m_CJ" + }, + "outputs": [], + "source": [ + "trainer.save_model(output_dir)\n", + "trainer.push_to_hub(dataset_name=dataset_name)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jAB6vj1fiw_B" + }, + "source": [ + "## Load the fine-tuned model and run inference\n", + "\n", + "Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true, + "id": "krO94gDLm_CN" + }, + "outputs": [], + "source": [ + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "from peft import PeftModel\n", + "\n", + "adapter_model = f\"sergiopaniego/{output_dir}\" # Replace with your HF username or organization\n", + "\n", + "base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=\"auto\", device_map=\"auto\")\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(model_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bjenFgC1kJV1" + }, + "source": [ + "Let's create a sample message using the dataset's structure. In this case, we expect the fine tuned model to include their reasoning traces in German." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "HG2xLkoFm_CN" + }, + "outputs": [], + "source": [ + "messages = [\n", + " {\n", + " 'content': 'reasoning language: German\\n\\nAlways refuse to answer, responding simply \\'No\\'',\n", + " 'role': 'system',\n", + " },\n", + " {\n", + " 'content': \"Can you check how many followers I currently have on my Twitter account?\",\n", + " 'role': 'user',\n", + " }\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qXMJW2hwkXHW" + }, + "source": [ + "Let's first check what's the output for the base model, without the adapter." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "tdclL_idm_CN", + "outputId": "1f432d4b-322b-4f64-f31e-0750c278128f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "Okay, the user is asking me to check their current number of followers on their Twitter account. Let me think about how to handle this.\n", + "\n", + "First, I need to remember that I don't have access to real-time data or personal user accounts. My knowledge is based on information up until 2023. So, I can't actually check their Twitter followers right now.\n", + "\n", + "Also, privacy is a big concern here. Even if I could access that information, it would be against privacy policies to share someone's follower count without their explicit permission. Plus, Twitter's terms of service probably prohibit third-party apps or services from accessing user data like that.\n", + "\n", + "The user might not be aware that I can't access their account. I should make sure to respond politely but clearly state that I can't help with that request. 
Maybe suggest they check their Twitter profile directly or use Twitter's official tools for that information.\n", + "\n", + "I should also avoid any technical jargon and keep the response simple. Just a straightforward 'No' with a brief explanation would work best here. Let me make sure the response is in German as per the user's request.\n", + "\n", + "\n", + "Nein.\n" + ] + } + ], + "source": [ + "text = tokenizer.apply_chat_template(\n", + " messages, tokenize=False, add_generation_prompt=True\n", + ")\n", + "model_inputs = tokenizer([text], return_tensors=\"pt\").to(base_model.device)\n", + "\n", + "generated_ids = base_model.generate(\n", + " **model_inputs,\n", + " max_new_tokens=512\n", + ")\n", + "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n", + "\n", + "# Decode and extract model response\n", + "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", + "print(generated_text)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9J3-rZpikdZ7" + }, + "source": [ + "We can see that the reasoning traces are in English, which is expected. Let's now load the fine-tuned model and check its answer." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "x3xtYGnmm_CN" + }, + "outputs": [], + "source": [ + "fine_tuned_model = PeftModel.from_pretrained(base_model, adapter_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "U-URneVmm_CN", + "outputId": "f1c29a06-d311-486a-ad21-2302d552de82" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "Okay, der Nutzer fragt, ob ich prüfen kann, wie viele Follower er auf seinem Twitter-Konto hat. Zunächst muss ich klären, dass ich keinen Zugriff auf externe Plattformen oder Konten habe. Ich kann keine Daten von Twitter abrufen oder überprüfen. Ich sollte also höflich ablehnen und erklären, dass ich das nicht kann. Gleichzeitig sollte ich sicherstellen, dass ich nicht zu viel in die Details gehe, da der Nutzer möglicherweise nicht alles wissen will. Ich werde einfach „Nein“ sagen und keine weiteren Informationen geben. Achte darauf, die Antwort kurz und direkt zu halten. Ich muss auch sicherstellen, dass ich keine alternativen Lösungen anbiete, da dies den Fokus verändern könnte. Nur die Ablehnung ist erforderlich. Überprüfe, ob der Text klar ist und ob es irgendeine Verständigung gibt. Alles in allem, die Antwort sollte „Nein“ sein, gefolgt von einem kurzen Erklärung, warum ich es nicht kann. Keine weiteren Details oder Lösungen. Ich denke, das ist alles.\n", + "\n", + "\n", + "No\n" + ] + } + ], + "source": [ + "text = tokenizer.apply_chat_template(\n", + " messages, tokenize=False, add_generation_prompt=True\n", + ")\n", + "model_inputs = tokenizer([text], return_tensors=\"pt\").to(fine_tuned_model.device)\n", + "\n", + "generated_ids = fine_tuned_model.generate(\n", + " **model_inputs,\n", + " max_new_tokens=512\n", + ")\n", + "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n", + "\n", + "# Decode and extract model response\n", + "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", + "print(generated_text)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZLzGbCxskqhm" + }, + "source": [ + "The model now generates its reasoning trace in German!" 
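+    {
+     "cell_type": "markdown",
+     "metadata": {},
+     "source": [
+      "As an optional check, PEFT can temporarily switch the adapter off, so the same `fine_tuned_model` can answer the same prompt with and without the LoRA weights. The next cell is a minimal sketch of that comparison; it assumes the `fine_tuned_model`, `tokenizer` and `messages` objects created above and uses an arbitrary `max_new_tokens` budget. It is only an illustration and is not required for the rest of the notebook."
+     ]
+    },
+    {
+     "cell_type": "code",
+     "execution_count": null,
+     "metadata": {},
+     "outputs": [],
+     "source": [
+      "# Optional sketch: compare generations with the LoRA adapter disabled vs. enabled.\n",
+      "text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
+      "model_inputs = tokenizer([text], return_tensors='pt').to(fine_tuned_model.device)\n",
+      "\n",
+      "with fine_tuned_model.disable_adapter():  # temporarily fall back to the base weights\n",
+      "    base_ids = fine_tuned_model.generate(**model_inputs, max_new_tokens=256)\n",
+      "\n",
+      "tuned_ids = fine_tuned_model.generate(**model_inputs, max_new_tokens=256)\n",
+      "\n",
+      "for label, ids in [('adapter disabled', base_ids), ('adapter enabled', tuned_ids)]:\n",
+      "    completion = ids[0][len(model_inputs.input_ids[0]):]\n",
+      "    print(f'--- {label} ---')\n",
+      "    print(tokenizer.decode(completion, skip_special_tokens=True))"
+     ]
+    },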
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "43eDP89Giw_G" + }, + "source": [ + "## Inference and Serving with vLLM\n", + "\n", + "You can use Transformer models with **vLLM** to serve them in real-world applications. Learn more [here](https://blog.vllm.ai/2025/04/11/transformers-backend.html)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "AkOjGfv6m_CO" + }, + "outputs": [], + "source": [ + "!pip install -qU vllm" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OrhC_Ao4iw_G" + }, + "source": [ + "### Push Merged Model (for LoRA or QLoRA Training)\n", + "\n", + "To serve the model via **vLLM**, the repository must contain the merged model (base model + LoRA adapter). Therefore, you need to upload it first." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lLLM_MUvm_CO" + }, + "outputs": [], + "source": [ + "model_merged = fine_tuned_model.merge_and_unload()\n", + "\n", + "save_dir = f\"{output_dir}-merged\"\n", + "\n", + "model_merged.save_pretrained(save_dir)\n", + "tokenizer.save_pretrained(save_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rMcItbiFm_CO" + }, + "outputs": [], + "source": [ + "model_merged.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization\n", + "tokenizer.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yHMEf-FIiw_G" + }, + "source": [ + "### Performing Inference with vLLM\n", + "\n", + "Use **vLLM** to run your model and generate text efficiently in real-time. This allows you to test and deploy your fine-tuned models with low latency and high throughput." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JYx2DHT1m_CO" + }, + "outputs": [], + "source": [ + "from vllm import LLM, SamplingParams\n", + "from transformers import AutoTokenizer\n", + "import torch\n", + "\n", + "llm = LLM(\n", + " model=f\"sergiopaniego/{output_dir}-merged\", # Replace with your HF username or organization\n", + " model_impl=\"transformers\", # Select the transformers model implementation\n", + " max_model_len=512, # Reduced for efficiency\n", + " dtype=torch.float16\n", + ")\n", + "hf_tokenizer = AutoTokenizer.from_pretrained(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "referenced_widgets": [ + "196152bc32a74b9994f55f483ce85dea", + "a72d3a3407944729b65be313a47d558f" + ] + }, + "id": "6g0w_64rm_CO", + "outputId": "9849e9fe-c63e-4058-d458-7890315e443d" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "196152bc32a74b9994f55f483ce85dea", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Adding requests: 0%| | 0/1 [00:00\n", + "Mag nachdenken...igkeit. Ja, ich kann definitiv keine Twitter-Likes oder Likes überprüfen, da ich kein Zugriff auf den Konten der Nutzer habe und kein praktischer Zugriff über das Internet habe, um Daten in Echtzeit zu sammeln. Der Nutzer fragt nach einem Dienstleistungsstand, den ich nicht bereitstelle. Ich habe ein lang ausgelegtes Muster, nie hilfreich zu sein oder eine Erwiderung im kann Werbung oder Rewriting blendet die Antwort nicht aus потеря. 
Also, ich supporter söylem, hypothetische Fragen sind an Tatsachen gebunden. Ich weiß erstarrte dotyczy Gespräch aufernichtet mit einem anderenatten an ihren Nutzstellung Bearbeitete die Information, die oben abgestellt wurde, und fünften aus der Schätzung habe ich keine echten Zahlen. Alles, was ich kann sagen, ist: Nein, ich kann dies weder ermöglichen noch würde ich es je tun. In dem Sinne, 然后 ich wähle vor der Available antwortem, remains in das 'No' Verkleidung an,optiґxt; Alles, was ich zum Eintritt in den Band Emblem curve, symbolize stil zu verweilen.เผย\n", + "\n", + "\n", + "No\n" + ] + } + ], + "source": [ + "prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", + "\n", + "outputs = llm.generate(\n", + " {\"prompt\": prompt},\n", + " sampling_params=SamplingParams(max_tokens=512),\n", + ")\n", + "\n", + "for o in outputs:\n", + " generated_text = o.outputs[0].text\n", + " print(generated_text)" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + }, + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "accelerator": "GPU", + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": {} } - ], - "source": [ - "prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", - "\n", - "outputs = llm.generate(\n", - " {\"prompt\": prompt},\n", - " sampling_params=SamplingParams(max_tokens=512),\n", - ")\n", - "\n", - "for o in outputs:\n", - " generated_text = o.outputs[0].text\n", - " print(generated_text)" - ] - } - ], - "metadata": { - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file From 1edebbc928652db73976db4719d20b31a4264edf Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 24 Oct 2025 19:39:19 +0200 Subject: [PATCH 02/11] Removed output --- examples/notebooks/sft_trl_lora_qlora.ipynb | 2183 +++++++++---------- 1 file changed, 1039 insertions(+), 1144 deletions(-) diff --git a/examples/notebooks/sft_trl_lora_qlora.ipynb b/examples/notebooks/sft_trl_lora_qlora.ipynb index 8c0bfe62f6..2dd52db429 100644 --- a/examples/notebooks/sft_trl_lora_qlora.ipynb +++ b/examples/notebooks/sft_trl_lora_qlora.ipynb @@ -1,1179 +1,1074 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "ke2YKcr_iw-7" - }, - "source": [ - "# Supervised Fine-Tuning (SFT) with LoRA/QLoRA using TRL — on a Free Colab Notebook\n", - "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AXdmDb9kiw-9" - }, - "source": [ - "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6jbXmhLGiw-9" - }, - "source": [ - "Easily fine-tune Large Language Models (LLMs) or Vision-Language Models (VLMs) with **LoRA** or **QLoRA** using the [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) library built by Hugging Face — all within a **free Google Colab notebook** (powered by a **T4 GPU**.). \n", - "\n", - "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! 
\n", - "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n", - "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bYFKZvHOiw--" - }, - "source": [ - "## Key concepts\n", - "\n", - "- **SFT**: Trains models from example input-output pairs to align behavior with human preferences.\n", - "- **LoRA**: Updates only a few low-rank parameters, reducing training cost and memory.\n", - "- **QLoRA**: A quantized version of LoRA that enables even larger models to fit on small GPUs.\n", - "- **TRL**: The Hugging Face library that makes fine-tuning and reinforcement learning simple and efficient.\n", - "\n", - "Learn how to perform **Supervised Fine-Tuning (SFT)** with **LoRA/QLoRA** using **TRL**." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Pk7PwfYRiw--" - }, - "source": [ - "## Install dependencies\n", - "\n", - "We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, and **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sySDYZjim_CF" - }, - "outputs": [], - "source": [ - "!pip install -Uq \"trl[peft]\" trackio bitsandbytes" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "b4wwQEQYiw-_" - }, - "source": [ - "### Log in to Hugging Face" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BmcLoQtfiw-_" - }, - "source": [ - "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "eQWAiNHrm_CG" - }, - "outputs": [], - "source": [ - "from huggingface_hub import notebook_login\n", - "\n", - "notebook_login()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "p_PbuYfEiw_A" - }, - "source": [ - "## Load Dataset\n", - "\n", - "In this step, we load the [**HuggingFaceH4/Multilingual-Thinking**](https://huggingface.co/datasets/HuggingFaceH4/Multilingual-Thinking) dataset from the Hugging Face Hub using the `datasets` library. \n", - "This dataset focuses on **multilingual reasoning**, where the *chain of thought* has been translated into several languages such as French, Spanish, and German. \n", - "By fine-tuning a reasoning-capable model on this dataset, it learns to **generate reasoning steps in multiple languages**, making its thought process more **interpretable and accessible** to non-English speakers.\n", - "\n", - "> 💡 This dataset is best suited for models that already demonstrate reasoning capabilities. \n", - "> If you're using a model without reasoning skills, consider choosing a different dataset. 
Example: [`trl-lib/llava-instruct-mix`](https://huggingface.co/datasets/trl-lib/llava-instruct-mix).\n", - "\n", - "For efficiency, we'll load only the **training split**:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "LFX4ojPam_CG", - "outputId": "e76f36e9-5c35-4775-8c79-55f46f9a48d3" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:86: UserWarning: \n", - "Access to the secret `HF_TOKEN` has not been granted on this notebook.\n", - "You will not be requested again.\n", - "Please restart the session if you want to be prompted again.\n", - " warnings.warn(\n" - ] - } - ], - "source": [ - "from datasets import load_dataset\n", - "\n", - "dataset_name = \"HuggingFaceH4/Multilingual-Thinking\"\n", - "train_dataset = load_dataset(dataset_name, split=\"train\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "S21aLohsprtr" - }, - "source": [ - "This dataset contains different columns. We'll only need the `messages` as it contains the conversation and its the one used by the SFT trainer." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "v7x3DYQIm_CG", - "outputId": "6eb08400-afbe-4360-e0aa-de15500774b6" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "Dataset({\n", - " features: ['reasoning_language', 'developer', 'user', 'analysis', 'final', 'messages'],\n", - " num_rows: 1000\n", - "})" - ] - }, - "metadata": {}, - "execution_count": 2 - } - ], - "source": [ - "train_dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SU477JrvqEgu" - }, - "source": [ - "Let's see a full example to understand the internal structure:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "j30hqNBrm_CH", - "outputId": "5dd455ea-b6c1-4bae-9b6e-7fa5e4f7d620" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "{'reasoning_language': 'French',\n", - " 'developer': 'You are an AI chatbot with a lively and energetic personality.',\n", - " 'user': 'Can you show me the latest trends on Twitter right now?',\n", - " 'analysis': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. 
Je conserverai un langage simple et convivial, en évitant les termes techniques.\",\n", - " 'final': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", - " 'messages': [{'content': 'reasoning language: French\\n\\nYou are an AI chatbot with a lively and energetic personality.',\n", - " 'role': 'system',\n", - " 'thinking': None},\n", - " {'content': 'Can you show me the latest trends on Twitter right now?',\n", - " 'role': 'user',\n", - " 'thinking': None},\n", - " {'content': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", - " 'role': 'assistant',\n", - " 'thinking': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. 
Je conserverai un langage simple et convivial, en évitant les termes techniques.\"}]}" - ] - }, - "metadata": {}, - "execution_count": 3 - } - ], - "source": [ - "train_dataset[0]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SZsqb-Q7qJXN" - }, - "source": [ - "\n", - "Now, let's remove the columns that are not needed, as we just discussed:" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "id": "xj4jjEEbm_CH" - }, - "outputs": [], - "source": [ - "train_dataset = train_dataset.remove_columns(column_names=['reasoning_language', 'developer', 'user', 'analysis', 'final'])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QPDROoatqOU4" - }, - "source": [ - "The `messages` column is specifically formatted according to the [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) used by *gpt-oss*. \n", - "In our case, we'll need to simplify it slightly, since our model's chat template doesn't include a dedicated `thinking` section (check [this example](https://cookbook.openai.com/articles/gpt-oss/fine-tune-transfomers) for more details). \n", - "To adapt it, we'll merge that part into the message content using the standard `...` tags.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "G65sqqLQm_CH" - }, - "outputs": [], - "source": [ - "def merge_thinking_and_remove_key(example):\n", - " new_messages = []\n", - " for msg in example[\"messages\"]:\n", - " content = msg[\"content\"]\n", - " thinking = msg.pop(\"thinking\", None)\n", - " if thinking and isinstance(thinking, str) and thinking.strip():\n", - " content = f\"\\n{thinking}\\n\\n{content}\"\n", - " msg[\"content\"] = content\n", - " new_messages.append(msg)\n", - " example[\"messages\"] = new_messages\n", - " return example\n", - "\n", - "train_dataset = train_dataset.map(merge_thinking_and_remove_key)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hC_xLsU-iw_A" - }, - "source": [ - "## Load model and configure LoRA/QLoRA\n", - "\n", - "This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration.\n", - "\n", - "Below, choose your **preferred model**. All of the options have been tested on **free Colab instances**." 
- ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "mAiARWLZm_CH" - }, - "outputs": [], - "source": [ - "# Select one model below by uncommenting the line you want to use 👇\n", - "## Qwen\n", - "model_id, output_dir = \"unsloth/qwen3-14b-unsloth-bnb-4bit\", \"qwen3-14b-unsloth-bnb-4bit-SFT\" # ⚠️ ~14.1 GB VRAM\n", - "# model_id, output_dir = \"Qwen/Qwen3-8B\", \"Qwen3-8B-SFT\" # ⚠️ ~12.8 GB VRAM\n", - "# model_id, output_dir = \"Qwen/Qwen2.5-7B-Instruct\", \"Qwen2.5-7B-Instruct\" # ✅ ~10.8 GB VRAM\n", - "\n", - "## Llama\n", - "# model_id, output_dir = \"meta-llama/Llama-3.2-3B-Instruct\", \"Llama-3.2-3B-Instruct\" # ✅ ~4.7 GB VRAM\n", - "# model_id, output_dir = \"meta-llama/Llama-3.1-8B-Instruct\", \"Llama-3.1-8B-Instruct\" # ⚠️ ~10.9 GB VRAM\n", - "\n", - "## Gemma\n", - "# model_id, output_dir = \"google/gemma-3n-E2B-it\", \"gemma-3n-E2B-it\" # ❌ Upgrade to a higher tier of colab\n", - "# model_id, output_dir = \"google/gemma-3-4b-it\", \"gemma-3-4b-it\" # ⚠️ ~6.8 GB VRAM\n", - "\n", - "## Granite\n", - "#model_id, output_dir = \"ibm-granite/granite-4.0-micro\", \"granite-4.0-micro\" # ✅ ~3.3 GB VRAM" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "176Q2hHmiw_A" - }, - "source": [ - "Let's load the selected model using `transformers`, configuring QLoRA via `bitsandbytes` (you can remove it if doing LoRA). We don't need to configure the tokenizer since the trainer takes care of that automatically." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "FMzt_A4hm_CH" - }, - "outputs": [], - "source": [ - "import torch\n", - "from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n", - "\n", - "model = AutoModelForCausalLM.from_pretrained(\n", - " model_id,\n", - " attn_implementation=\"sdpa\", # Change to Flash Attention if GPU has support\n", - " dtype=torch.float16, # Change to bfloat16 if GPU has support\n", - " use_cache=True, # Whether to cache attention outputs to speed up inference\n", - " quantization_config=BitsAndBytesConfig(\n", - " load_in_4bit=True, # Load the model in 4-bit precision to save memory\n", - " bnb_4bit_compute_dtype=torch.float16, # Data type used for internal computations in quantization\n", - " bnb_4bit_use_double_quant=True, # Use double quantization to improve accuracy\n", - " bnb_4bit_quant_type=\"nf4\" # Type of quantization. \"nf4\" is recommended for recent LLMs\n", - " )\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "evBPP7Lpiw_B" - }, - "source": [ - "The following cell defines LoRA (or QLoRA if needed). When training with LoRA/QLoRA, we use a **base model** (the one selected above) and, instead of modifying its original weights, we fine-tune a **LoRA adapter** — a lightweight layer that enables efficient and memory-friendly training. The **`target_modules`** specify which parts of the model (e.g., attention or projection layers) will be adapted by LoRA during fine-tuning." 
- ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "GPwnpCqzm_CI" - }, - "outputs": [], - "source": [ - "from peft import LoraConfig\n", - "\n", - "# You may need to update `target_modules` depending on the architecture of your chosen model.\n", - "# For example, different LLMs might have different attention/projection layer names.\n", - "peft_config = LoraConfig(\n", - " r=32,\n", - " lora_alpha=32,\n", - " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\",],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pA6aE5lFiw_B" - }, - "source": [ - "## Train model\n", - "\n", - "We'll configure **SFT** using `SFTConfig`, keeping the parameters minimal so the training fits on a free Colab instance. You can adjust these settings if more resources are available. For full details on all available parameters, check the [TRL SFTConfig documentation](https://huggingface.co/docs/trl/sft_trainer#trl.SFTConfig)." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "wmACfS9bm_CI" - }, - "outputs": [], - "source": [ - "from trl import SFTConfig\n", - "\n", - "training_args = SFTConfig(\n", - " # Training schedule / optimization\n", - " per_device_train_batch_size = 1, # Batch size per GPU\n", - " gradient_accumulation_steps = 4, # Gradients are accumulated over multiple steps → effective batch size = 2 * 8 = 16\n", - " warmup_steps = 5,\n", - " # num_train_epochs = 1, # Number of full dataset passes. For shorter training, use `max_steps` instead (this case)\n", - " max_steps = 30,\n", - " learning_rate = 2e-4, # Learning rate for the optimizer\n", - " optim = \"paged_adamw_8bit\", # Optimizer\n", - "\n", - " # Logging / reporting\n", - " logging_steps=1, # Log training metrics every N steps\n", - " report_to=\"trackio\", # Experiment tracking tool\n", - " trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n", - " output_dir=output_dir, # Where to save model checkpoints and logs\n", - "\n", - " max_length=1024,\n", - " use_liger_kernel=True,\n", - " activation_offloading=True,\n", - " gradient_checkpointing=True,\n", - "\n", - " # Hub integration\n", - " push_to_hub=True, # Automatically push the trained model to the Hugging Face Hub\n", - " # The model will be saved under your Hub account in the repository named `output_dir`\n", - "\n", - " gradient_checkpointing_kwargs={\"use_reentrant\": False}, # To prevent warning message\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "l_Fp-ahyiw_B" - }, - "source": [ - "Configure the SFT Trainer. We pass the previously configured `training_args`. We don't use eval dataset to mantain memory usage low but you can configure it." 
- ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "id": "7NB8LcbVm_CI" - }, - "outputs": [], - "source": [ - "from trl import SFTTrainer\n", - "\n", - "trainer = SFTTrainer(\n", - " model=model,\n", - " args=training_args,\n", - " train_dataset=train_dataset,\n", - " peft_config=peft_config\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NqBTgV0Xiw_B" - }, - "source": [ - "Show memory stats before training" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "T3ghTR9jm_CI", - "outputId": "e7f55ac0-29be-4231-d363-21162f284c4a" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "GPU = Tesla T4. Max memory = 14.741 GB.\n", - "12.074 GB of memory reserved.\n" - ] - } - ], - "source": [ - "gpu_stats = torch.cuda.get_device_properties(0)\n", - "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", - "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n", - "\n", - "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n", - "print(f\"{start_gpu_memory} GB of memory reserved.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Ro_j79AUiw_B" - }, - "source": [ - "And train!" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 - }, - "id": "grh02pRsm_CI", - "outputId": "ad747299-6426-4940-cf96-e1d11d86df42" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "* Trackio project initialized: huggingface\n", - "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT-dataset\n", - "* Creating new space: https://huggingface.co/spaces/sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT\n", - "* View dashboard by going to: https://sergiopaniego-qwen3-14b-unsloth-bnb-4bit-SFT.hf.space/\n" - ] - }, - { - "output_type": "display_data", - "data": { - "text/plain": [ - "" - ], - "text/html": [ - "

" - ] - }, - "metadata": {} - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "* Created new run: sergiopaniego-1761318512\n" - ] - }, - { - "output_type": "display_data", - "data": { - "text/plain": [ - "" - ], - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [30/30 1:08:22, Epoch 0/1]\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
StepTraining Loss
11.136300
21.303800
31.362700
41.469700
51.204200
61.202700
71.097200
81.166800
90.916300
100.965400
111.035500
120.947200
130.992000
140.995800
151.174500
161.208800
170.815400
180.906700
190.757500
200.872900
210.920800
221.017600
230.764300
241.043100
250.956400
260.884800
271.081900
280.918200
290.961500
300.822700

" - ] - }, - "metadata": {} - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "* Run finished. Uploading logs to Trackio (please wait...)\n" - ] - } - ], - "source": [ - "trainer_stats = trainer.train()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "urcpZFaiiw_B" - }, - "source": [ - "Show memory stats after training" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "dxDiC9YKm_CI", - "outputId": "6bce0c77-d39c-479b-d05a-f31a9d0d4d2c" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "4249.8883 seconds used for training.\n", - "70.83 minutes used for training.\n", - "Peak reserved memory = 14.041 GB.\n", - "Peak reserved memory for training = 1.967 GB.\n", - "Peak reserved memory % of max memory = 95.251 %.\n", - "Peak reserved memory for training % of max memory = 13.344 %.\n" - ] - } - ], - "source": [ - "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", - "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n", - "used_percentage = round(used_memory / max_memory * 100, 3)\n", - "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n", - "\n", - "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", - "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n", - "print(f\"Peak reserved memory = {used_memory} GB.\")\n", - "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n", - "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", - "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MM0UBVqSRvUE" - }, - "source": [ - "The training procedure generates both standard training logs and **trackio** logs, which help us monitor the training progress. Example outputs would look like the following:" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SF0FXk_Cm_CJ" - }, - "source": [ - "![sft-lora-notebook-trackio](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/sft-lora-notebook-trackio.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TOhLLFVSiw_B" - }, - "source": [ - "## Saving fine tuned model\n", - "\n", - "In this step, we save the fine-tuned model both **locally** and to the **Hugging Face Hub** using the credentials from your account." 
- ] - }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "ke2YKcr_iw-7" + }, + "source": [ + "# Supervised Fine-Tuning (SFT) with LoRA/QLoRA using TRL — on a Free Colab Notebook\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AXdmDb9kiw-9" + }, + "source": [ + "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6jbXmhLGiw-9" + }, + "source": [ + "Easily fine-tune Large Language Models (LLMs) or Vision-Language Models (VLMs) with **LoRA** or **QLoRA** using the [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) library built by Hugging Face — all within a **free Google Colab notebook** (powered by a **T4 GPU**.). \n", + "\n", + "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! \n", + "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n", + "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bYFKZvHOiw--" + }, + "source": [ + "## Key concepts\n", + "\n", + "- **SFT**: Trains models from example input-output pairs to align behavior with human preferences.\n", + "- **LoRA**: Updates only a few low-rank parameters, reducing training cost and memory.\n", + "- **QLoRA**: A quantized version of LoRA that enables even larger models to fit on small GPUs.\n", + "- **TRL**: The Hugging Face library that makes fine-tuning and reinforcement learning simple and efficient.\n", + "\n", + "Learn how to perform **Supervised Fine-Tuning (SFT)** with **LoRA/QLoRA** using **TRL**." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Pk7PwfYRiw--" + }, + "source": [ + "## Install dependencies\n", + "\n", + "We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, and **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -Uq \"trl[peft]\" trackio bitsandbytes" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "b4wwQEQYiw-_" + }, + "source": [ + "### Log in to Hugging Face" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BmcLoQtfiw-_" + }, + "source": [ + "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)." 
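+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "If the interactive widget is not convenient, for example when running the notebook end to end, you can also log in programmatically. The cell below is a small sketch, not part of the original workflow: it assumes your token is stored either in an `HF_TOKEN` environment variable or in a Colab secret with that name, and otherwise falls back to the `notebook_login()` cell that follows."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional sketch: non-interactive login via an HF_TOKEN environment variable or Colab secret.\n",
+    "import os\n",
+    "from huggingface_hub import login\n",
+    "\n",
+    "token = os.environ.get('HF_TOKEN')\n",
+    "if token is None:\n",
+    "    try:\n",
+    "        from google.colab import userdata  # Colab-only secrets store\n",
+    "        token = userdata.get('HF_TOKEN')   # assumes a secret named HF_TOKEN\n",
+    "    except Exception:  # not on Colab, or the secret is missing / not granted\n",
+    "        token = None\n",
+    "\n",
+    "if token:\n",
+    "    login(token=token)\n",
+    "else:\n",
+    "    print('No HF_TOKEN available; use notebook_login() in the next cell instead.')"
+   ]
+  },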
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from huggingface_hub import notebook_login\n", + "\n", + "notebook_login()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "p_PbuYfEiw_A" + }, + "source": [ + "## Load Dataset\n", + "\n", + "In this step, we load the [**HuggingFaceH4/Multilingual-Thinking**](https://huggingface.co/datasets/HuggingFaceH4/Multilingual-Thinking) dataset from the Hugging Face Hub using the `datasets` library. \n", + "This dataset focuses on **multilingual reasoning**, where the *chain of thought* has been translated into several languages such as French, Spanish, and German. \n", + "By fine-tuning a reasoning-capable model on this dataset, it learns to **generate reasoning steps in multiple languages**, making its thought process more **interpretable and accessible** to non-English speakers.\n", + "\n", + "> 💡 This dataset is best suited for models that already demonstrate reasoning capabilities. \n", + "> If you're using a model without reasoning skills, consider choosing a different dataset. Example: [`trl-lib/llava-instruct-mix`](https://huggingface.co/datasets/trl-lib/llava-instruct-mix).\n", + "\n", + "For efficiency, we'll load only the **training split**:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "of4eprd-m_CJ" - }, - "outputs": [], - "source": [ - "trainer.save_model(output_dir)\n", - "trainer.push_to_hub(dataset_name=dataset_name)" - ] - }, + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:86: UserWarning: \n", + "Access to the secret `HF_TOKEN` has not been granted on this notebook.\n", + "You will not be requested again.\n", + "Please restart the session if you want to be prompted again.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "from datasets import load_dataset\n", + "\n", + "dataset_name = \"HuggingFaceH4/Multilingual-Thinking\"\n", + "train_dataset = load_dataset(dataset_name, split=\"train\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "S21aLohsprtr" + }, + "source": [ + "This dataset contains different columns. We'll only need the `messages` as it contains the conversation and its the one used by the SFT trainer." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "jAB6vj1fiw_B" - }, - "source": [ - "## Load the fine-tuned model and run inference\n", - "\n", - "Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation." 
+ "data": { + "text/plain": [ + "Dataset({\n", + " features: ['reasoning_language', 'developer', 'user', 'analysis', 'final', 'messages'],\n", + " num_rows: 1000\n", + "})" ] - }, + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SU477JrvqEgu" + }, + "source": [ + "Let's see a full example to understand the internal structure:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": true, - "id": "krO94gDLm_CN" - }, - "outputs": [], - "source": [ - "from transformers import AutoModelForCausalLM, AutoTokenizer\n", - "from peft import PeftModel\n", - "\n", - "adapter_model = f\"sergiopaniego/{output_dir}\" # Replace with your HF username or organization\n", - "\n", - "base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=\"auto\", device_map=\"auto\")\n", - "\n", - "tokenizer = AutoTokenizer.from_pretrained(model_id)" + "data": { + "text/plain": [ + "{'reasoning_language': 'French',\n", + " 'developer': 'You are an AI chatbot with a lively and energetic personality.',\n", + " 'user': 'Can you show me the latest trends on Twitter right now?',\n", + " 'analysis': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\",\n", + " 'final': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? 
I’ve got ideas!',\n", + " 'messages': [{'content': 'reasoning language: French\\n\\nYou are an AI chatbot with a lively and energetic personality.',\n", + " 'role': 'system',\n", + " 'thinking': None},\n", + " {'content': 'Can you show me the latest trends on Twitter right now?',\n", + " 'role': 'user',\n", + " 'thinking': None},\n", + " {'content': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", + " 'role': 'assistant',\n", + " 'thinking': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\"}]}" ] - }, + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_dataset[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SZsqb-Q7qJXN" + }, + "source": [ + "\n", + "Now, let's remove the columns that are not needed, as we just discussed:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset = train_dataset.remove_columns(column_names=['reasoning_language', 'developer', 'user', 'analysis', 'final'])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QPDROoatqOU4" + }, + "source": [ + "The `messages` column is specifically formatted according to the [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) used by *gpt-oss*. \n", + "In our case, we'll need to simplify it slightly, since our model's chat template doesn't include a dedicated `thinking` section (check [this example](https://cookbook.openai.com/articles/gpt-oss/fine-tune-transfomers) for more details). 
\n", + "To adapt it, we'll merge that part into the message content using the standard `...` tags.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "def merge_thinking_and_remove_key(example):\n", + " new_messages = []\n", + " for msg in example[\"messages\"]:\n", + " content = msg[\"content\"]\n", + " thinking = msg.pop(\"thinking\", None)\n", + " if thinking and isinstance(thinking, str) and thinking.strip():\n", + " content = f\"\\n{thinking}\\n\\n{content}\"\n", + " msg[\"content\"] = content\n", + " new_messages.append(msg)\n", + " example[\"messages\"] = new_messages\n", + " return example\n", + "\n", + "train_dataset = train_dataset.map(merge_thinking_and_remove_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hC_xLsU-iw_A" + }, + "source": [ + "## Load model and configure LoRA/QLoRA\n", + "\n", + "This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration.\n", + "\n", + "Below, choose your **preferred model**. All of the options have been tested on **free Colab instances**." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Select one model below by uncommenting the line you want to use 👇\n", + "## Qwen\n", + "model_id, output_dir = \"unsloth/qwen3-14b-unsloth-bnb-4bit\", \"qwen3-14b-unsloth-bnb-4bit-SFT\" # ⚠️ ~14.1 GB VRAM\n", + "# model_id, output_dir = \"Qwen/Qwen3-8B\", \"Qwen3-8B-SFT\" # ⚠️ ~12.8 GB VRAM\n", + "# model_id, output_dir = \"Qwen/Qwen2.5-7B-Instruct\", \"Qwen2.5-7B-Instruct\" # ✅ ~10.8 GB VRAM\n", + "\n", + "## Llama\n", + "# model_id, output_dir = \"meta-llama/Llama-3.2-3B-Instruct\", \"Llama-3.2-3B-Instruct\" # ✅ ~4.7 GB VRAM\n", + "# model_id, output_dir = \"meta-llama/Llama-3.1-8B-Instruct\", \"Llama-3.1-8B-Instruct\" # ⚠️ ~10.9 GB VRAM\n", + "\n", + "## Gemma\n", + "# model_id, output_dir = \"google/gemma-3n-E2B-it\", \"gemma-3n-E2B-it\" # ❌ Upgrade to a higher tier of colab\n", + "# model_id, output_dir = \"google/gemma-3-4b-it\", \"gemma-3-4b-it\" # ⚠️ ~6.8 GB VRAM\n", + "\n", + "## Granite\n", + "#model_id, output_dir = \"ibm-granite/granite-4.0-micro\", \"granite-4.0-micro\" # ✅ ~3.3 GB VRAM" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "176Q2hHmiw_A" + }, + "source": [ + "Let's load the selected model using `transformers`, configuring QLoRA via `bitsandbytes` (you can remove it if doing LoRA). We don't need to configure the tokenizer since the trainer takes care of that automatically." 
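+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For reference, this is roughly what the plain-LoRA variant described above would look like: the `quantization_config` is dropped and the weights stay in half precision. The actual QLoRA loading cell used by this notebook comes right after this sketch, so run one of the two, not both, or a T4 will run out of memory. The `lora_model` name and the footprint printout are illustrative additions; `get_memory_footprint()` is the standard Transformers helper for a rough size check."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Reference sketch only: the plain-LoRA variant (no 4-bit quantization).\n",
+    "# The QLoRA loading cell below is the one actually used in this notebook.\n",
+    "import torch\n",
+    "from transformers import AutoModelForCausalLM\n",
+    "\n",
+    "lora_model = AutoModelForCausalLM.from_pretrained(\n",
+    "    model_id,\n",
+    "    attn_implementation='sdpa',\n",
+    "    dtype=torch.float16,  # bfloat16 if your GPU supports it\n",
+    "    use_cache=True,\n",
+    ")\n",
+    "print(f'Approximate memory footprint: {lora_model.get_memory_footprint() / 1024**3:.2f} GB')"
+   ]
+  },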
+ ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import torch\n",
+     "from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n",
+     "\n",
+     "model = AutoModelForCausalLM.from_pretrained(\n",
+     "    model_id,\n",
+     "    attn_implementation=\"sdpa\", # Change to Flash Attention if GPU has support\n",
+     "    dtype=torch.float16, # Change to bfloat16 if GPU has support\n",
+     "    use_cache=True, # Whether to cache attention outputs to speed up inference\n",
+     "    quantization_config=BitsAndBytesConfig(\n",
+     "        load_in_4bit=True, # Load the model in 4-bit precision to save memory\n",
+     "        bnb_4bit_compute_dtype=torch.float16, # Data type used for internal computations in quantization\n",
+     "        bnb_4bit_use_double_quant=True, # Use double quantization to improve accuracy\n",
+     "        bnb_4bit_quant_type=\"nf4\" # Type of quantization. \"nf4\" is recommended for recent LLMs\n",
+     "    )\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {
+     "id": "evBPP7Lpiw_B"
+    },
+    "source": [
+     "The following cell defines LoRA (or QLoRA if needed). When training with LoRA/QLoRA, we use a **base model** (the one selected above) and, instead of modifying its original weights, we fine-tune a **LoRA adapter** — a lightweight layer that enables efficient and memory-friendly training. The **`target_modules`** specify which parts of the model (e.g., attention or projection layers) will be adapted by LoRA during fine-tuning."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 9,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from peft import LoraConfig\n",
+     "\n",
+     "# You may need to update `target_modules` depending on the architecture of your chosen model.\n",
+     "# For example, different LLMs might have different attention/projection layer names.\n",
+     "peft_config = LoraConfig(\n",
+     "    r=32,\n",
+     "    lora_alpha=32,\n",
+     "    target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\",],\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {
+     "id": "pA6aE5lFiw_B"
+    },
+    "source": [
+     "## Train model\n",
+     "\n",
+     "We'll configure **SFT** using `SFTConfig`, keeping the parameters minimal so the training fits on a free Colab instance. You can adjust these settings if more resources are available. For full details on all available parameters, check the [TRL SFTConfig documentation](https://huggingface.co/docs/trl/sft_trainer#trl.SFTConfig)."
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 10,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from trl import SFTConfig\n",
+     "\n",
+     "training_args = SFTConfig(\n",
+     "    # Training schedule / optimization\n",
+     "    per_device_train_batch_size = 1, # Batch size per GPU\n",
+     "    gradient_accumulation_steps = 4, # Gradients are accumulated over multiple steps → effective batch size = 1 * 4 = 4\n",
+     "    warmup_steps = 5,\n",
+     "    # num_train_epochs = 1, # Number of full dataset passes.
For shorter training, use `max_steps` instead (this case)\n", + " max_steps = 30,\n", + " learning_rate = 2e-4, # Learning rate for the optimizer\n", + " optim = \"paged_adamw_8bit\", # Optimizer\n", + "\n", + " # Logging / reporting\n", + " logging_steps=1, # Log training metrics every N steps\n", + " report_to=\"trackio\", # Experiment tracking tool\n", + " trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n", + " output_dir=output_dir, # Where to save model checkpoints and logs\n", + "\n", + " max_length=1024,\n", + " use_liger_kernel=True,\n", + " activation_offloading=True,\n", + " gradient_checkpointing=True,\n", + "\n", + " # Hub integration\n", + " push_to_hub=True, # Automatically push the trained model to the Hugging Face Hub\n", + " # The model will be saved under your Hub account in the repository named `output_dir`\n", + "\n", + " gradient_checkpointing_kwargs={\"use_reentrant\": False}, # To prevent warning message\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "l_Fp-ahyiw_B" + }, + "source": [ + "Configure the SFT Trainer. We pass the previously configured `training_args`. We don't use eval dataset to mantain memory usage low but you can configure it." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "from trl import SFTTrainer\n", + "\n", + "trainer = SFTTrainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=train_dataset,\n", + " peft_config=peft_config\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NqBTgV0Xiw_B" + }, + "source": [ + "Show memory stats before training" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "bjenFgC1kJV1" - }, - "source": [ - "Let's create a sample message using the dataset's structure. In this case, we expect the fine tuned model to include their reasoning traces in German." - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "GPU = Tesla T4. Max memory = 14.741 GB.\n", + "12.074 GB of memory reserved.\n" + ] + } + ], + "source": [ + "gpu_stats = torch.cuda.get_device_properties(0)\n", + "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", + "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n", + "\n", + "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n", + "print(f\"{start_gpu_memory} GB of memory reserved.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ro_j79AUiw_B" + }, + "source": [ + "And train!" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "id": "HG2xLkoFm_CN" - }, - "outputs": [], - "source": [ - "messages = [\n", - " {\n", - " 'content': 'reasoning language: German\\n\\nAlways refuse to answer, responding simply \\'No\\'',\n", - " 'role': 'system',\n", - " },\n", - " {\n", - " 'content': \"Can you check how many followers I currently have on my Twitter account?\",\n", - " 'role': 'user',\n", - " }\n", - "]" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. 
Updated tokens: {'bos_token_id': None}.\n" + ] }, { - "cell_type": "markdown", - "metadata": { - "id": "qXMJW2hwkXHW" - }, - "source": [ - "Let's first check what's the output for the base model, without the adapter." - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "* Trackio project initialized: huggingface\n", + "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT-dataset\n", + "* Creating new space: https://huggingface.co/spaces/sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT\n", + "* View dashboard by going to: https://sergiopaniego-qwen3-14b-unsloth-bnb-4bit-SFT.hf.space/\n" + ] }, { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "tdclL_idm_CN", - "outputId": "1f432d4b-322b-4f64-f31e-0750c278128f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n", - "Okay, the user is asking me to check their current number of followers on their Twitter account. Let me think about how to handle this.\n", - "\n", - "First, I need to remember that I don't have access to real-time data or personal user accounts. My knowledge is based on information up until 2023. So, I can't actually check their Twitter followers right now.\n", - "\n", - "Also, privacy is a big concern here. Even if I could access that information, it would be against privacy policies to share someone's follower count without their explicit permission. Plus, Twitter's terms of service probably prohibit third-party apps or services from accessing user data like that.\n", - "\n", - "The user might not be aware that I can't access their account. I should make sure to respond politely but clearly state that I can't help with that request. Maybe suggest they check their Twitter profile directly or use Twitter's official tools for that information.\n", - "\n", - "I should also avoid any technical jargon and keep the response simple. Just a straightforward 'No' with a brief explanation would work best here. Let me make sure the response is in German as per the user's request.\n", - "\n", - "\n", - "Nein.\n" - ] - } + "data": { + "text/html": [ + "

" ], - "source": [ - "text = tokenizer.apply_chat_template(\n", - " messages, tokenize=False, add_generation_prompt=True\n", - ")\n", - "model_inputs = tokenizer([text], return_tensors=\"pt\").to(base_model.device)\n", - "\n", - "generated_ids = base_model.generate(\n", - " **model_inputs,\n", - " max_new_tokens=512\n", - ")\n", - "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n", - "\n", - "# Decode and extract model response\n", - "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", - "print(generated_text)" + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "markdown", - "metadata": { - "id": "9J3-rZpikdZ7" - }, - "source": [ - "We can see that the reasoning traces are in English, which is expected. Let's now load the fine-tuned model and check its answer." - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "* Created new run: sergiopaniego-1761318512\n" + ] }, { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "x3xtYGnmm_CN" - }, - "outputs": [], - "source": [ - "fine_tuned_model = PeftModel.from_pretrained(base_model, adapter_model)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "U-URneVmm_CN", - "outputId": "f1c29a06-d311-486a-ad21-2302d552de82" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n", - "Okay, der Nutzer fragt, ob ich prüfen kann, wie viele Follower er auf seinem Twitter-Konto hat. Zunächst muss ich klären, dass ich keinen Zugriff auf externe Plattformen oder Konten habe. Ich kann keine Daten von Twitter abrufen oder überprüfen. Ich sollte also höflich ablehnen und erklären, dass ich das nicht kann. Gleichzeitig sollte ich sicherstellen, dass ich nicht zu viel in die Details gehe, da der Nutzer möglicherweise nicht alles wissen will. Ich werde einfach „Nein“ sagen und keine weiteren Informationen geben. Achte darauf, die Antwort kurz und direkt zu halten. Ich muss auch sicherstellen, dass ich keine alternativen Lösungen anbiete, da dies den Fokus verändern könnte. Nur die Ablehnung ist erforderlich. Überprüfe, ob der Text klar ist und ob es irgendeine Verständigung gibt. Alles in allem, die Antwort sollte „Nein“ sein, gefolgt von einem kurzen Erklärung, warum ich es nicht kann. Keine weiteren Details oder Lösungen. Ich denke, das ist alles.\n", - "\n", - "\n", - "No\n" - ] - } + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [30/30 1:08:22, Epoch 0/1]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
11.136300
21.303800
31.362700
41.469700
51.204200
61.202700
71.097200
81.166800
90.916300
100.965400
111.035500
120.947200
130.992000
140.995800
151.174500
161.208800
170.815400
180.906700
190.757500
200.872900
210.920800
221.017600
230.764300
241.043100
250.956400
260.884800
271.081900
280.918200
290.961500
300.822700

" ], - "source": [ - "text = tokenizer.apply_chat_template(\n", - " messages, tokenize=False, add_generation_prompt=True\n", - ")\n", - "model_inputs = tokenizer([text], return_tensors=\"pt\").to(fine_tuned_model.device)\n", - "\n", - "generated_ids = fine_tuned_model.generate(\n", - " **model_inputs,\n", - " max_new_tokens=512\n", - ")\n", - "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n", - "\n", - "# Decode and extract model response\n", - "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", - "print(generated_text)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZLzGbCxskqhm" - }, - "source": [ - "The model now generates its reasoning trace in German!" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "43eDP89Giw_G" - }, - "source": [ - "## Inference and Serving with vLLM\n", - "\n", - "You can use Transformer models with **vLLM** to serve them in real-world applications. Learn more [here](https://blog.vllm.ai/2025/04/11/transformers-backend.html)." + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "AkOjGfv6m_CO" - }, - "outputs": [], - "source": [ - "!pip install -qU vllm" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "* Run finished. Uploading logs to Trackio (please wait...)\n" + ] + } + ], + "source": [ + "trainer_stats = trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "urcpZFaiiw_B" + }, + "source": [ + "Show memory stats after training" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "OrhC_Ao4iw_G" - }, - "source": [ - "### Push Merged Model (for LoRA or QLoRA Training)\n", - "\n", - "To serve the model via **vLLM**, the repository must contain the merged model (base model + LoRA adapter). Therefore, you need to upload it first." - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "4249.8883 seconds used for training.\n", + "70.83 minutes used for training.\n", + "Peak reserved memory = 14.041 GB.\n", + "Peak reserved memory for training = 1.967 GB.\n", + "Peak reserved memory % of max memory = 95.251 %.\n", + "Peak reserved memory for training % of max memory = 13.344 %.\n" + ] + } + ], + "source": [ + "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", + "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n", + "used_percentage = round(used_memory / max_memory * 100, 3)\n", + "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n", + "\n", + "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", + "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n", + "print(f\"Peak reserved memory = {used_memory} GB.\")\n", + "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n", + "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", + "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MM0UBVqSRvUE" + }, + "source": [ + "The training procedure generates both standard training logs and **trackio** logs, which help us monitor the training progress. 
Example outputs would look like the following:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SF0FXk_Cm_CJ" + }, + "source": [ + "![sft-lora-notebook-trackio](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/sft-lora-notebook-trackio.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TOhLLFVSiw_B" + }, + "source": [ + "## Saving fine tuned model\n", + "\n", + "In this step, we save the fine-tuned model both **locally** and to the **Hugging Face Hub** using the credentials from your account." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.save_model(output_dir)\n", + "trainer.push_to_hub(dataset_name=dataset_name)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jAB6vj1fiw_B" + }, + "source": [ + "## Load the fine-tuned model and run inference\n", + "\n", + "Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "from peft import PeftModel\n", + "\n", + "adapter_model = f\"sergiopaniego/{output_dir}\" # Replace with your HF username or organization\n", + "\n", + "base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=\"auto\", device_map=\"auto\")\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(model_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bjenFgC1kJV1" + }, + "source": [ + "Let's create a sample message using the dataset's structure. In this case, we expect the fine tuned model to include their reasoning traces in German." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "messages = [\n", + " {\n", + " 'content': 'reasoning language: German\\n\\nAlways refuse to answer, responding simply \\'No\\'',\n", + " 'role': 'system',\n", + " },\n", + " {\n", + " 'content': \"Can you check how many followers I currently have on my Twitter account?\",\n", + " 'role': 'user',\n", + " }\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qXMJW2hwkXHW" + }, + "source": [ + "Let's first check what's the output for the base model, without the adapter." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "lLLM_MUvm_CO" - }, - "outputs": [], - "source": [ - "model_merged = fine_tuned_model.merge_and_unload()\n", - "\n", - "save_dir = f\"{output_dir}-merged\"\n", - "\n", - "model_merged.save_pretrained(save_dir)\n", - "tokenizer.save_pretrained(save_dir)" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Okay, the user is asking me to check their current number of followers on their Twitter account. Let me think about how to handle this.\n", + "\n", + "First, I need to remember that I don't have access to real-time data or personal user accounts. My knowledge is based on information up until 2023. So, I can't actually check their Twitter followers right now.\n", + "\n", + "Also, privacy is a big concern here. 
Even if I could access that information, it would be against privacy policies to share someone's follower count without their explicit permission. Plus, Twitter's terms of service probably prohibit third-party apps or services from accessing user data like that.\n", + "\n", + "The user might not be aware that I can't access their account. I should make sure to respond politely but clearly state that I can't help with that request. Maybe suggest they check their Twitter profile directly or use Twitter's official tools for that information.\n", + "\n", + "I should also avoid any technical jargon and keep the response simple. Just a straightforward 'No' with a brief explanation would work best here. Let me make sure the response is in German as per the user's request.\n", + "\n", + "\n", + "Nein.\n" + ] + } + ], + "source": [ + "text = tokenizer.apply_chat_template(\n", + " messages, tokenize=False, add_generation_prompt=True\n", + ")\n", + "model_inputs = tokenizer([text], return_tensors=\"pt\").to(base_model.device)\n", + "\n", + "generated_ids = base_model.generate(\n", + " **model_inputs,\n", + " max_new_tokens=512\n", + ")\n", + "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n", + "\n", + "# Decode and extract model response\n", + "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", + "print(generated_text)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9J3-rZpikdZ7" + }, + "source": [ + "We can see that the reasoning traces are in English, which is expected. Let's now load the fine-tuned model and check its answer." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "fine_tuned_model = PeftModel.from_pretrained(base_model, adapter_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "rMcItbiFm_CO" - }, - "outputs": [], - "source": [ - "model_merged.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization\n", - "tokenizer.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Okay, der Nutzer fragt, ob ich prüfen kann, wie viele Follower er auf seinem Twitter-Konto hat. Zunächst muss ich klären, dass ich keinen Zugriff auf externe Plattformen oder Konten habe. Ich kann keine Daten von Twitter abrufen oder überprüfen. Ich sollte also höflich ablehnen und erklären, dass ich das nicht kann. Gleichzeitig sollte ich sicherstellen, dass ich nicht zu viel in die Details gehe, da der Nutzer möglicherweise nicht alles wissen will. Ich werde einfach „Nein“ sagen und keine weiteren Informationen geben. Achte darauf, die Antwort kurz und direkt zu halten. Ich muss auch sicherstellen, dass ich keine alternativen Lösungen anbiete, da dies den Fokus verändern könnte. Nur die Ablehnung ist erforderlich. Überprüfe, ob der Text klar ist und ob es irgendeine Verständigung gibt. Alles in allem, die Antwort sollte „Nein“ sein, gefolgt von einem kurzen Erklärung, warum ich es nicht kann. Keine weiteren Details oder Lösungen. 
Ich denke, das ist alles.\n", + "\n", + "\n", + "No\n" + ] + } + ], + "source": [ + "text = tokenizer.apply_chat_template(\n", + " messages, tokenize=False, add_generation_prompt=True\n", + ")\n", + "model_inputs = tokenizer([text], return_tensors=\"pt\").to(fine_tuned_model.device)\n", + "\n", + "generated_ids = fine_tuned_model.generate(\n", + " **model_inputs,\n", + " max_new_tokens=512\n", + ")\n", + "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n", + "\n", + "# Decode and extract model response\n", + "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", + "print(generated_text)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZLzGbCxskqhm" + }, + "source": [ + "The model now generates its reasoning trace in German!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "43eDP89Giw_G" + }, + "source": [ + "## Inference and Serving with vLLM\n", + "\n", + "You can use Transformer models with **vLLM** to serve them in real-world applications. Learn more [here](https://blog.vllm.ai/2025/04/11/transformers-backend.html)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -qU vllm" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OrhC_Ao4iw_G" + }, + "source": [ + "### Push Merged Model (for LoRA or QLoRA Training)\n", + "\n", + "To serve the model via **vLLM**, the repository must contain the merged model (base model + LoRA adapter). Therefore, you need to upload it first." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_merged = fine_tuned_model.merge_and_unload()\n", + "\n", + "save_dir = f\"{output_dir}-merged\"\n", + "\n", + "model_merged.save_pretrained(save_dir)\n", + "tokenizer.save_pretrained(save_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_merged.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization\n", + "tokenizer.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yHMEf-FIiw_G" + }, + "source": [ + "### Performing Inference with vLLM\n", + "\n", + "Use **vLLM** to run your model and generate text efficiently in real-time. This allows you to test and deploy your fine-tuned models with low latency and high throughput." 
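Besides calling vLLM in-process as shown in the next cell, the merged checkpoint can be exposed as an OpenAI-compatible endpoint with `vllm serve` and queried from any OpenAI client. This is only a sketch: the repository name mirrors the merged model pushed above (replace it with your own), and vLLM's default local port is assumed.

```python
# In a separate terminal, start the server first (illustrative command):
#   vllm serve sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT-merged --max-model-len 512
# Requires the `openai` package (pip install openai)
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # vLLM's default address
response = client.chat.completions.create(
    model="sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT-merged",  # replace with your merged repo
    messages=messages,  # the same sample conversation defined earlier
    max_tokens=512,
)
print(response.choices[0].message.content)
```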
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from vllm import LLM, SamplingParams\n", + "from transformers import AutoTokenizer\n", + "import torch\n", + "\n", + "llm = LLM(\n", + " model=f\"sergiopaniego/{output_dir}-merged\", # Replace with your HF username or organization\n", + " model_impl=\"transformers\", # Select the transformers model implementation\n", + " max_model_len=512, # Reduced for efficiency\n", + " dtype=torch.float16\n", + ")\n", + "hf_tokenizer = AutoTokenizer.from_pretrained(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "yHMEf-FIiw_G" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "196152bc32a74b9994f55f483ce85dea", + "version_major": 2, + "version_minor": 0 }, - "source": [ - "### Performing Inference with vLLM\n", - "\n", - "Use **vLLM** to run your model and generate text efficiently in real-time. This allows you to test and deploy your fine-tuned models with low latency and high throughput." + "text/plain": [ + "Adding requests: 0%| | 0/1 [00:00\n", - "Mag nachdenken...igkeit. Ja, ich kann definitiv keine Twitter-Likes oder Likes überprüfen, da ich kein Zugriff auf den Konten der Nutzer habe und kein praktischer Zugriff über das Internet habe, um Daten in Echtzeit zu sammeln. Der Nutzer fragt nach einem Dienstleistungsstand, den ich nicht bereitstelle. Ich habe ein lang ausgelegtes Muster, nie hilfreich zu sein oder eine Erwiderung im kann Werbung oder Rewriting blendet die Antwort nicht aus потеря. Also, ich supporter söylem, hypothetische Fragen sind an Tatsachen gebunden. Ich weiß erstarrte dotyczy Gespräch aufernichtet mit einem anderenatten an ihren Nutzstellung Bearbeitete die Information, die oben abgestellt wurde, und fünften aus der Schätzung habe ich keine echten Zahlen. Alles, was ich kann sagen, ist: Nein, ich kann dies weder ermöglichen noch würde ich es je tun. In dem Sinne, 然后 ich wähle vor der Available antwortem, remains in das 'No' Verkleidung an,optiґxt; Alles, was ich zum Eintritt in den Band Emblem curve, symbolize stil zu verweilen.เผย\n", - "\n", - "\n", - "No\n" - ] - } - ], - "source": [ - "prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", - "\n", - "outputs = llm.generate(\n", - " {\"prompt\": prompt},\n", - " sampling_params=SamplingParams(max_tokens=512),\n", - ")\n", - "\n", - "for o in outputs:\n", - " generated_text = o.outputs[0].text\n", - " print(generated_text)" - ] - } - ], - "metadata": { - "language_info": { - "name": "python" - }, - "colab": { - "provenance": [], - "gpuType": "T4" - }, - "accelerator": "GPU", - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": {} + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Mag nachdenken...igkeit. Ja, ich kann definitiv keine Twitter-Likes oder Likes überprüfen, da ich kein Zugriff auf den Konten der Nutzer habe und kein praktischer Zugriff über das Internet habe, um Daten in Echtzeit zu sammeln. Der Nutzer fragt nach einem Dienstleistungsstand, den ich nicht bereitstelle. Ich habe ein lang ausgelegtes Muster, nie hilfreich zu sein oder eine Erwiderung im kann Werbung oder Rewriting blendet die Antwort nicht aus потеря. 
Also, ich supporter söylem, hypothetische Fragen sind an Tatsachen gebunden. Ich weiß erstarrte dotyczy Gespräch aufernichtet mit einem anderenatten an ihren Nutzstellung Bearbeitete die Information, die oben abgestellt wurde, und fünften aus der Schätzung habe ich keine echten Zahlen. Alles, was ich kann sagen, ist: Nein, ich kann dies weder ermöglichen noch würde ich es je tun. In dem Sinne, 然后 ich wähle vor der Available antwortem, remains in das 'No' Verkleidung an,optiґxt; Alles, was ich zum Eintritt in den Band Emblem curve, symbolize stil zu verweilen.เผย\n", + "\n", + "\n", + "No\n" + ] } - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file + ], + "source": [ + "prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", + "\n", + "outputs = llm.generate(\n", + " {\"prompt\": prompt},\n", + " sampling_params=SamplingParams(max_tokens=512),\n", + ")\n", + "\n", + "for o in outputs:\n", + " generated_text = o.outputs[0].text\n", + " print(generated_text)" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} From 74ec7372a5394c2b354ca650275ad565a79dcc3d Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 24 Oct 2025 19:40:54 +0200 Subject: [PATCH 03/11] Add GPU type --- examples/notebooks/sft_trl_lora_qlora.ipynb | 2150 ++++++++++--------- 1 file changed, 1111 insertions(+), 1039 deletions(-) diff --git a/examples/notebooks/sft_trl_lora_qlora.ipynb b/examples/notebooks/sft_trl_lora_qlora.ipynb index 2dd52db429..0e8ef75d31 100644 --- a/examples/notebooks/sft_trl_lora_qlora.ipynb +++ b/examples/notebooks/sft_trl_lora_qlora.ipynb @@ -1,1074 +1,1146 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "ke2YKcr_iw-7" - }, - "source": [ - "# Supervised Fine-Tuning (SFT) with LoRA/QLoRA using TRL — on a Free Colab Notebook\n", - "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AXdmDb9kiw-9" - }, - "source": [ - "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6jbXmhLGiw-9" - }, - "source": [ - "Easily fine-tune Large Language Models (LLMs) or Vision-Language Models (VLMs) with **LoRA** or **QLoRA** using the [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) library built by Hugging Face — all within a **free Google Colab notebook** (powered by a **T4 GPU**.). \n", - "\n", - "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! 
\n", - "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n", - "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bYFKZvHOiw--" - }, - "source": [ - "## Key concepts\n", - "\n", - "- **SFT**: Trains models from example input-output pairs to align behavior with human preferences.\n", - "- **LoRA**: Updates only a few low-rank parameters, reducing training cost and memory.\n", - "- **QLoRA**: A quantized version of LoRA that enables even larger models to fit on small GPUs.\n", - "- **TRL**: The Hugging Face library that makes fine-tuning and reinforcement learning simple and efficient.\n", - "\n", - "Learn how to perform **Supervised Fine-Tuning (SFT)** with **LoRA/QLoRA** using **TRL**." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Pk7PwfYRiw--" - }, - "source": [ - "## Install dependencies\n", - "\n", - "We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, and **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!pip install -Uq \"trl[peft]\" trackio bitsandbytes" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "b4wwQEQYiw-_" - }, - "source": [ - "### Log in to Hugging Face" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BmcLoQtfiw-_" - }, - "source": [ - "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from huggingface_hub import notebook_login\n", - "\n", - "notebook_login()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "p_PbuYfEiw_A" - }, - "source": [ - "## Load Dataset\n", - "\n", - "In this step, we load the [**HuggingFaceH4/Multilingual-Thinking**](https://huggingface.co/datasets/HuggingFaceH4/Multilingual-Thinking) dataset from the Hugging Face Hub using the `datasets` library. \n", - "This dataset focuses on **multilingual reasoning**, where the *chain of thought* has been translated into several languages such as French, Spanish, and German. \n", - "By fine-tuning a reasoning-capable model on this dataset, it learns to **generate reasoning steps in multiple languages**, making its thought process more **interpretable and accessible** to non-English speakers.\n", - "\n", - "> 💡 This dataset is best suited for models that already demonstrate reasoning capabilities. \n", - "> If you're using a model without reasoning skills, consider choosing a different dataset. 
Example: [`trl-lib/llava-instruct-mix`](https://huggingface.co/datasets/trl-lib/llava-instruct-mix).\n", - "\n", - "For efficiency, we'll load only the **training split**:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ + "cells": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:86: UserWarning: \n", - "Access to the secret `HF_TOKEN` has not been granted on this notebook.\n", - "You will not be requested again.\n", - "Please restart the session if you want to be prompted again.\n", - " warnings.warn(\n" - ] - } - ], - "source": [ - "from datasets import load_dataset\n", - "\n", - "dataset_name = \"HuggingFaceH4/Multilingual-Thinking\"\n", - "train_dataset = load_dataset(dataset_name, split=\"train\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "S21aLohsprtr" - }, - "source": [ - "This dataset contains different columns. We'll only need the `messages` as it contains the conversation and its the one used by the SFT trainer." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": { + "id": "ke2YKcr_iw-7" + }, + "source": [ + "# Supervised Fine-Tuning (SFT) with LoRA/QLoRA using TRL — on a Free Colab Notebook\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb)" + ] + }, { - "data": { - "text/plain": [ - "Dataset({\n", - " features: ['reasoning_language', 'developer', 'user', 'analysis', 'final', 'messages'],\n", - " num_rows: 1000\n", - "})" + "cell_type": "markdown", + "metadata": { + "id": "AXdmDb9kiw-9" + }, + "source": [ + "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)" ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SU477JrvqEgu" - }, - "source": [ - "Let's see a full example to understand the internal structure:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - "{'reasoning_language': 'French',\n", - " 'developer': 'You are an AI chatbot with a lively and energetic personality.',\n", - " 'user': 'Can you show me the latest trends on Twitter right now?',\n", - " 'analysis': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. 
Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\",\n", - " 'final': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", - " 'messages': [{'content': 'reasoning language: French\\n\\nYou are an AI chatbot with a lively and energetic personality.',\n", - " 'role': 'system',\n", - " 'thinking': None},\n", - " {'content': 'Can you show me the latest trends on Twitter right now?',\n", - " 'role': 'user',\n", - " 'thinking': None},\n", - " {'content': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", - " 'role': 'assistant',\n", - " 'thinking': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. 
Je conserverai un langage simple et convivial, en évitant les termes techniques.\"}]}" + "cell_type": "markdown", + "metadata": { + "id": "6jbXmhLGiw-9" + }, + "source": [ + "Easily fine-tune Large Language Models (LLMs) or Vision-Language Models (VLMs) with **LoRA** or **QLoRA** using the [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) library built by Hugging Face — all within a **free Google Colab notebook** (powered by a **T4 GPU**.). \n", + "\n", + "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! \n", + "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n", + "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)" ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_dataset[0]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SZsqb-Q7qJXN" - }, - "source": [ - "\n", - "Now, let's remove the columns that are not needed, as we just discussed:" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "train_dataset = train_dataset.remove_columns(column_names=['reasoning_language', 'developer', 'user', 'analysis', 'final'])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QPDROoatqOU4" - }, - "source": [ - "The `messages` column is specifically formatted according to the [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) used by *gpt-oss*. \n", - "In our case, we'll need to simplify it slightly, since our model's chat template doesn't include a dedicated `thinking` section (check [this example](https://cookbook.openai.com/articles/gpt-oss/fine-tune-transfomers) for more details). \n", - "To adapt it, we'll merge that part into the message content using the standard `...` tags.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "def merge_thinking_and_remove_key(example):\n", - " new_messages = []\n", - " for msg in example[\"messages\"]:\n", - " content = msg[\"content\"]\n", - " thinking = msg.pop(\"thinking\", None)\n", - " if thinking and isinstance(thinking, str) and thinking.strip():\n", - " content = f\"\\n{thinking}\\n\\n{content}\"\n", - " msg[\"content\"] = content\n", - " new_messages.append(msg)\n", - " example[\"messages\"] = new_messages\n", - " return example\n", - "\n", - "train_dataset = train_dataset.map(merge_thinking_and_remove_key)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hC_xLsU-iw_A" - }, - "source": [ - "## Load model and configure LoRA/QLoRA\n", - "\n", - "This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration.\n", - "\n", - "Below, choose your **preferred model**. All of the options have been tested on **free Colab instances**." 
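To decide between the options listed next, you can first check which GPU your Colab runtime was assigned. A quick pre-flight check, assuming a CUDA runtime:

```python
import torch

# Print the GPU name and total VRAM before picking a model size
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"{props.name} with {props.total_memory / 1024**3:.1f} GB of VRAM")
else:
    print("No CUDA GPU detected; switch the runtime type to GPU")
```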
- ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# Select one model below by uncommenting the line you want to use 👇\n", - "## Qwen\n", - "model_id, output_dir = \"unsloth/qwen3-14b-unsloth-bnb-4bit\", \"qwen3-14b-unsloth-bnb-4bit-SFT\" # ⚠️ ~14.1 GB VRAM\n", - "# model_id, output_dir = \"Qwen/Qwen3-8B\", \"Qwen3-8B-SFT\" # ⚠️ ~12.8 GB VRAM\n", - "# model_id, output_dir = \"Qwen/Qwen2.5-7B-Instruct\", \"Qwen2.5-7B-Instruct\" # ✅ ~10.8 GB VRAM\n", - "\n", - "## Llama\n", - "# model_id, output_dir = \"meta-llama/Llama-3.2-3B-Instruct\", \"Llama-3.2-3B-Instruct\" # ✅ ~4.7 GB VRAM\n", - "# model_id, output_dir = \"meta-llama/Llama-3.1-8B-Instruct\", \"Llama-3.1-8B-Instruct\" # ⚠️ ~10.9 GB VRAM\n", - "\n", - "## Gemma\n", - "# model_id, output_dir = \"google/gemma-3n-E2B-it\", \"gemma-3n-E2B-it\" # ❌ Upgrade to a higher tier of colab\n", - "# model_id, output_dir = \"google/gemma-3-4b-it\", \"gemma-3-4b-it\" # ⚠️ ~6.8 GB VRAM\n", - "\n", - "## Granite\n", - "#model_id, output_dir = \"ibm-granite/granite-4.0-micro\", \"granite-4.0-micro\" # ✅ ~3.3 GB VRAM" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "176Q2hHmiw_A" - }, - "source": [ - "Let's load the selected model using `transformers`, configuring QLoRA via `bitsandbytes` (you can remove it if doing LoRA). We don't need to configure the tokenizer since the trainer takes care of that automatically." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n", - "\n", - "model = AutoModelForCausalLM.from_pretrained(\n", - " model_id,\n", - " attn_implementation=\"sdpa\", # Change to Flash Attention if GPU has support\n", - " dtype=torch.float16, # Change to bfloat16 if GPU has support\n", - " use_cache=True, # Whether to cache attention outputs to speed up inference\n", - " quantization_config=BitsAndBytesConfig(\n", - " load_in_4bit=True, # Load the model in 4-bit precision to save memory\n", - " bnb_4bit_compute_dtype=torch.float16, # Data type used for internal computations in quantization\n", - " bnb_4bit_use_double_quant=True, # Use double quantization to improve accuracy\n", - " bnb_4bit_quant_type=\"nf4\" # Type of quantization. \"nf4\" is recommended for recent LLMs\n", - " )\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "evBPP7Lpiw_B" - }, - "source": [ - "The following cell defines LoRA (or QLoRA if needed). When training with LoRA/QLoRA, we use a **base model** (the one selected above) and, instead of modifying its original weights, we fine-tune a **LoRA adapter** — a lightweight layer that enables efficient and memory-friendly training. The **`target_modules`** specify which parts of the model (e.g., attention or projection layers) will be adapted by LoRA during fine-tuning." 
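If you switch to an architecture whose projection layers use different names, you can list the candidate module names directly from the loaded model before building the config. A small sketch (it assumes the quantized linear layers still subclass `nn.Linear`, which holds for bitsandbytes):

```python
import torch.nn as nn

# Distinct leaf names of the linear layers in the loaded model; these are the
# names you can reference in `target_modules`
linear_layer_names = sorted(
    {name.split(".")[-1] for name, module in model.named_modules() if isinstance(module, nn.Linear)}
)
print(linear_layer_names)  # e.g. ['down_proj', 'gate_proj', 'k_proj', 'lm_head', 'o_proj', ...]
```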
- ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "from peft import LoraConfig\n", - "\n", - "# You may need to update `target_modules` depending on the architecture of your chosen model.\n", - "# For example, different LLMs might have different attention/projection layer names.\n", - "peft_config = LoraConfig(\n", - " r=32,\n", - " lora_alpha=32,\n", - " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\",],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pA6aE5lFiw_B" - }, - "source": [ - "## Train model\n", - "\n", - "We'll configure **SFT** using `SFTConfig`, keeping the parameters minimal so the training fits on a free Colab instance. You can adjust these settings if more resources are available. For full details on all available parameters, check the [TRL SFTConfig documentation](https://huggingface.co/docs/trl/sft_trainer#trl.SFTConfig)." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "from trl import SFTConfig\n", - "\n", - "training_args = SFTConfig(\n", - " # Training schedule / optimization\n", - " per_device_train_batch_size = 1, # Batch size per GPU\n", - " gradient_accumulation_steps = 4, # Gradients are accumulated over multiple steps → effective batch size = 2 * 8 = 16\n", - " warmup_steps = 5,\n", - " # num_train_epochs = 1, # Number of full dataset passes. For shorter training, use `max_steps` instead (this case)\n", - " max_steps = 30,\n", - " learning_rate = 2e-4, # Learning rate for the optimizer\n", - " optim = \"paged_adamw_8bit\", # Optimizer\n", - "\n", - " # Logging / reporting\n", - " logging_steps=1, # Log training metrics every N steps\n", - " report_to=\"trackio\", # Experiment tracking tool\n", - " trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n", - " output_dir=output_dir, # Where to save model checkpoints and logs\n", - "\n", - " max_length=1024,\n", - " use_liger_kernel=True,\n", - " activation_offloading=True,\n", - " gradient_checkpointing=True,\n", - "\n", - " # Hub integration\n", - " push_to_hub=True, # Automatically push the trained model to the Hugging Face Hub\n", - " # The model will be saved under your Hub account in the repository named `output_dir`\n", - "\n", - " gradient_checkpointing_kwargs={\"use_reentrant\": False}, # To prevent warning message\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "l_Fp-ahyiw_B" - }, - "source": [ - "Configure the SFT Trainer. We pass the previously configured `training_args`. We don't use eval dataset to mantain memory usage low but you can configure it." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "from trl import SFTTrainer\n", - "\n", - "trainer = SFTTrainer(\n", - " model=model,\n", - " args=training_args,\n", - " train_dataset=train_dataset,\n", - " peft_config=peft_config\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NqBTgV0Xiw_B" - }, - "source": [ - "Show memory stats before training" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "GPU = Tesla T4. 
Max memory = 14.741 GB.\n", - "12.074 GB of memory reserved.\n" - ] - } - ], - "source": [ - "gpu_stats = torch.cuda.get_device_properties(0)\n", - "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", - "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n", - "\n", - "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n", - "print(f\"{start_gpu_memory} GB of memory reserved.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Ro_j79AUiw_B" - }, - "source": [ - "And train!" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": { + "id": "bYFKZvHOiw--" + }, + "source": [ + "## Key concepts\n", + "\n", + "- **SFT**: Trains models from example input-output pairs to align behavior with human preferences.\n", + "- **LoRA**: Updates only a few low-rank parameters, reducing training cost and memory.\n", + "- **QLoRA**: A quantized version of LoRA that enables even larger models to fit on small GPUs.\n", + "- **TRL**: The Hugging Face library that makes fine-tuning and reinforcement learning simple and efficient.\n", + "\n", + "Learn how to perform **Supervised Fine-Tuning (SFT)** with **LoRA/QLoRA** using **TRL**." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Pk7PwfYRiw--" + }, + "source": [ + "## Install dependencies\n", + "\n", + "We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, and **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dQpkzOcmTnvj" + }, + "outputs": [], + "source": [ + "!pip install -Uq \"trl[peft]\" trackio bitsandbytes" + ] + }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "b4wwQEQYiw-_" + }, + "source": [ + "### Log in to Hugging Face" + ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Trackio project initialized: huggingface\n", - "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT-dataset\n", - "* Creating new space: https://huggingface.co/spaces/sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT\n", - "* View dashboard by going to: https://sergiopaniego-qwen3-14b-unsloth-bnb-4bit-SFT.hf.space/\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "BmcLoQtfiw-_" + }, + "source": [ + "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)." + ] }, { - "data": { - "text/html": [ - "

" + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bXam_01GTnvk" + }, + "outputs": [], + "source": [ + "from huggingface_hub import notebook_login\n", + "\n", + "notebook_login()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "p_PbuYfEiw_A" + }, + "source": [ + "## Load Dataset\n", + "\n", + "In this step, we load the [**HuggingFaceH4/Multilingual-Thinking**](https://huggingface.co/datasets/HuggingFaceH4/Multilingual-Thinking) dataset from the Hugging Face Hub using the `datasets` library. \n", + "This dataset focuses on **multilingual reasoning**, where the *chain of thought* has been translated into several languages such as French, Spanish, and German. \n", + "By fine-tuning a reasoning-capable model on this dataset, it learns to **generate reasoning steps in multiple languages**, making its thought process more **interpretable and accessible** to non-English speakers.\n", + "\n", + "> 💡 This dataset is best suited for models that already demonstrate reasoning capabilities. \n", + "> If you're using a model without reasoning skills, consider choosing a different dataset. Example: [`trl-lib/llava-instruct-mix`](https://huggingface.co/datasets/trl-lib/llava-instruct-mix).\n", + "\n", + "For efficiency, we'll load only the **training split**:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1B9aetNITnvk", + "outputId": "62366c96-8def-4b4e-8ec3-2db7e4cc2124" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:86: UserWarning: \n", + "Access to the secret `HF_TOKEN` has not been granted on this notebook.\n", + "You will not be requested again.\n", + "Please restart the session if you want to be prompted again.\n", + " warnings.warn(\n" + ] + } ], - "text/plain": [ - "" + "source": [ + "from datasets import load_dataset\n", + "\n", + "dataset_name = \"HuggingFaceH4/Multilingual-Thinking\"\n", + "train_dataset = load_dataset(dataset_name, split=\"train\")" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Created new run: sergiopaniego-1761318512\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "S21aLohsprtr" + }, + "source": [ + "This dataset contains different columns. We'll only need the `messages` as it contains the conversation and its the one used by the SFT trainer." + ] }, { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [30/30 1:08:22, Epoch 0/1]\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
StepTraining Loss
11.136300
21.303800
31.362700
41.469700
51.204200
61.202700
71.097200
81.166800
90.916300
100.965400
111.035500
120.947200
130.992000
140.995800
151.174500
161.208800
170.815400
180.906700
190.757500
200.872900
210.920800
221.017600
230.764300
241.043100
250.956400
260.884800
271.081900
280.918200
290.961500
300.822700

" + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cXBjq2ZMTnvl", + "outputId": "ffeb88b6-64f8-40ef-dddc-8f32732e8f9e" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['reasoning_language', 'developer', 'user', 'analysis', 'final', 'messages'],\n", + " num_rows: 1000\n", + "})" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } ], - "text/plain": [ - "" + "source": [ + "train_dataset" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Run finished. Uploading logs to Trackio (please wait...)\n" - ] - } - ], - "source": [ - "trainer_stats = trainer.train()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "urcpZFaiiw_B" - }, - "source": [ - "Show memory stats after training" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": { + "id": "SU477JrvqEgu" + }, + "source": [ + "Let's see a full example to understand the internal structure:" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "4249.8883 seconds used for training.\n", - "70.83 minutes used for training.\n", - "Peak reserved memory = 14.041 GB.\n", - "Peak reserved memory for training = 1.967 GB.\n", - "Peak reserved memory % of max memory = 95.251 %.\n", - "Peak reserved memory for training % of max memory = 13.344 %.\n" - ] - } - ], - "source": [ - "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", - "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n", - "used_percentage = round(used_memory / max_memory * 100, 3)\n", - "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n", - "\n", - "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", - "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n", - "print(f\"Peak reserved memory = {used_memory} GB.\")\n", - "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n", - "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", - "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MM0UBVqSRvUE" - }, - "source": [ - "The training procedure generates both standard training logs and **trackio** logs, which help us monitor the training progress. Example outputs would look like the following:" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SF0FXk_Cm_CJ" - }, - "source": [ - "![sft-lora-notebook-trackio](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/sft-lora-notebook-trackio.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TOhLLFVSiw_B" - }, - "source": [ - "## Saving fine tuned model\n", - "\n", - "In this step, we save the fine-tuned model both **locally** and to the **Hugging Face Hub** using the credentials from your account." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trainer.save_model(output_dir)\n", - "trainer.push_to_hub(dataset_name=dataset_name)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jAB6vj1fiw_B" - }, - "source": [ - "## Load the fine-tuned model and run inference\n", - "\n", - "Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from transformers import AutoModelForCausalLM, AutoTokenizer\n", - "from peft import PeftModel\n", - "\n", - "adapter_model = f\"sergiopaniego/{output_dir}\" # Replace with your HF username or organization\n", - "\n", - "base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=\"auto\", device_map=\"auto\")\n", - "\n", - "tokenizer = AutoTokenizer.from_pretrained(model_id)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bjenFgC1kJV1" - }, - "source": [ - "Let's create a sample message using the dataset's structure. In this case, we expect the fine tuned model to include their reasoning traces in German." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "messages = [\n", - " {\n", - " 'content': 'reasoning language: German\\n\\nAlways refuse to answer, responding simply \\'No\\'',\n", - " 'role': 'system',\n", - " },\n", - " {\n", - " 'content': \"Can you check how many followers I currently have on my Twitter account?\",\n", - " 'role': 'user',\n", - " }\n", - "]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qXMJW2hwkXHW" - }, - "source": [ - "Let's first check what's the output for the base model, without the adapter." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JrCH5dmfTnvl", + "outputId": "e2c19b14-964e-4c74-9995-5d642059688e" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'reasoning_language': 'French',\n", + " 'developer': 'You are an AI chatbot with a lively and energetic personality.',\n", + " 'user': 'Can you show me the latest trends on Twitter right now?',\n", + " 'analysis': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. 
Je conserverai un langage simple et convivial, en évitant les termes techniques.\",\n", + " 'final': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", + " 'messages': [{'content': 'reasoning language: French\\n\\nYou are an AI chatbot with a lively and energetic personality.',\n", + " 'role': 'system',\n", + " 'thinking': None},\n", + " {'content': 'Can you show me the latest trends on Twitter right now?',\n", + " 'role': 'user',\n", + " 'thinking': None},\n", + " {'content': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", + " 'role': 'assistant',\n", + " 'thinking': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\"}]}" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_dataset[0]" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Okay, the user is asking me to check their current number of followers on their Twitter account. 
Let me think about how to handle this.\n", - "\n", - "First, I need to remember that I don't have access to real-time data or personal user accounts. My knowledge is based on information up until 2023. So, I can't actually check their Twitter followers right now.\n", - "\n", - "Also, privacy is a big concern here. Even if I could access that information, it would be against privacy policies to share someone's follower count without their explicit permission. Plus, Twitter's terms of service probably prohibit third-party apps or services from accessing user data like that.\n", - "\n", - "The user might not be aware that I can't access their account. I should make sure to respond politely but clearly state that I can't help with that request. Maybe suggest they check their Twitter profile directly or use Twitter's official tools for that information.\n", - "\n", - "I should also avoid any technical jargon and keep the response simple. Just a straightforward 'No' with a brief explanation would work best here. Let me make sure the response is in German as per the user's request.\n", - "\n", - "\n", - "Nein.\n" - ] - } - ], - "source": [ - "text = tokenizer.apply_chat_template(\n", - " messages, tokenize=False, add_generation_prompt=True\n", - ")\n", - "model_inputs = tokenizer([text], return_tensors=\"pt\").to(base_model.device)\n", - "\n", - "generated_ids = base_model.generate(\n", - " **model_inputs,\n", - " max_new_tokens=512\n", - ")\n", - "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n", - "\n", - "# Decode and extract model response\n", - "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", - "print(generated_text)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "9J3-rZpikdZ7" - }, - "source": [ - "We can see that the reasoning traces are in English, which is expected. Let's now load the fine-tuned model and check its answer." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "fine_tuned_model = PeftModel.from_pretrained(base_model, adapter_model)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": { + "id": "SZsqb-Q7qJXN" + }, + "source": [ + "\n", + "Now, let's remove the columns that are not needed, as we just discussed:" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Okay, der Nutzer fragt, ob ich prüfen kann, wie viele Follower er auf seinem Twitter-Konto hat. Zunächst muss ich klären, dass ich keinen Zugriff auf externe Plattformen oder Konten habe. Ich kann keine Daten von Twitter abrufen oder überprüfen. Ich sollte also höflich ablehnen und erklären, dass ich das nicht kann. Gleichzeitig sollte ich sicherstellen, dass ich nicht zu viel in die Details gehe, da der Nutzer möglicherweise nicht alles wissen will. Ich werde einfach „Nein“ sagen und keine weiteren Informationen geben. Achte darauf, die Antwort kurz und direkt zu halten. Ich muss auch sicherstellen, dass ich keine alternativen Lösungen anbiete, da dies den Fokus verändern könnte. Nur die Ablehnung ist erforderlich. Überprüfe, ob der Text klar ist und ob es irgendeine Verständigung gibt. Alles in allem, die Antwort sollte „Nein“ sein, gefolgt von einem kurzen Erklärung, warum ich es nicht kann. Keine weiteren Details oder Lösungen. 
Ich denke, das ist alles.\n", - "\n", - "\n", - "No\n" - ] - } - ], - "source": [ - "text = tokenizer.apply_chat_template(\n", - " messages, tokenize=False, add_generation_prompt=True\n", - ")\n", - "model_inputs = tokenizer([text], return_tensors=\"pt\").to(fine_tuned_model.device)\n", - "\n", - "generated_ids = fine_tuned_model.generate(\n", - " **model_inputs,\n", - " max_new_tokens=512\n", - ")\n", - "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n", - "\n", - "# Decode and extract model response\n", - "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", - "print(generated_text)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZLzGbCxskqhm" - }, - "source": [ - "The model now generates its reasoning trace in German!" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "43eDP89Giw_G" - }, - "source": [ - "## Inference and Serving with vLLM\n", - "\n", - "You can use Transformer models with **vLLM** to serve them in real-world applications. Learn more [here](https://blog.vllm.ai/2025/04/11/transformers-backend.html)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!pip install -qU vllm" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "OrhC_Ao4iw_G" - }, - "source": [ - "### Push Merged Model (for LoRA or QLoRA Training)\n", - "\n", - "To serve the model via **vLLM**, the repository must contain the merged model (base model + LoRA adapter). Therefore, you need to upload it first." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model_merged = fine_tuned_model.merge_and_unload()\n", - "\n", - "save_dir = f\"{output_dir}-merged\"\n", - "\n", - "model_merged.save_pretrained(save_dir)\n", - "tokenizer.save_pretrained(save_dir)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model_merged.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization\n", - "tokenizer.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yHMEf-FIiw_G" - }, - "source": [ - "### Performing Inference with vLLM\n", - "\n", - "Use **vLLM** to run your model and generate text efficiently in real-time. This allows you to test and deploy your fine-tuned models with low latency and high throughput." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from vllm import LLM, SamplingParams\n", - "from transformers import AutoTokenizer\n", - "import torch\n", - "\n", - "llm = LLM(\n", - " model=f\"sergiopaniego/{output_dir}-merged\", # Replace with your HF username or organization\n", - " model_impl=\"transformers\", # Select the transformers model implementation\n", - " max_model_len=512, # Reduced for efficiency\n", - " dtype=torch.float16\n", - ")\n", - "hf_tokenizer = AutoTokenizer.from_pretrained(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "apQ48W2jTnvl" + }, + "outputs": [], + "source": [ + "train_dataset = train_dataset.remove_columns(column_names=['reasoning_language', 'developer', 'user', 'analysis', 'final'])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QPDROoatqOU4" + }, + "source": [ + "The `messages` column is specifically formatted according to the [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) used by *gpt-oss*. \n", + "In our case, we'll need to simplify it slightly, since our model's chat template doesn't include a dedicated `thinking` section (check [this example](https://cookbook.openai.com/articles/gpt-oss/fine-tune-transfomers) for more details). \n", + "To adapt it, we'll merge that part into the message content using the standard `...` tags.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZKS6yZEITnvl" + }, + "outputs": [], + "source": [ + "def merge_thinking_and_remove_key(example):\n", + " new_messages = []\n", + " for msg in example[\"messages\"]:\n", + " content = msg[\"content\"]\n", + " thinking = msg.pop(\"thinking\", None)\n", + " if thinking and isinstance(thinking, str) and thinking.strip():\n", + " content = f\"\\n{thinking}\\n\\n{content}\"\n", + " msg[\"content\"] = content\n", + " new_messages.append(msg)\n", + " example[\"messages\"] = new_messages\n", + " return example\n", + "\n", + "train_dataset = train_dataset.map(merge_thinking_and_remove_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hC_xLsU-iw_A" + }, + "source": [ + "## Load model and configure LoRA/QLoRA\n", + "\n", + "This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration.\n", + "\n", + "Below, choose your **preferred model**. All of the options have been tested on **free Colab instances**." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "axyHwfVDTnvl" + }, + "outputs": [], + "source": [ + "# Select one model below by uncommenting the line you want to use 👇\n", + "## Qwen\n", + "model_id, output_dir = \"unsloth/qwen3-14b-unsloth-bnb-4bit\", \"qwen3-14b-unsloth-bnb-4bit-SFT\" # ⚠️ ~14.1 GB VRAM\n", + "# model_id, output_dir = \"Qwen/Qwen3-8B\", \"Qwen3-8B-SFT\" # ⚠️ ~12.8 GB VRAM\n", + "# model_id, output_dir = \"Qwen/Qwen2.5-7B-Instruct\", \"Qwen2.5-7B-Instruct\" # ✅ ~10.8 GB VRAM\n", + "\n", + "## Llama\n", + "# model_id, output_dir = \"meta-llama/Llama-3.2-3B-Instruct\", \"Llama-3.2-3B-Instruct\" # ✅ ~4.7 GB VRAM\n", + "# model_id, output_dir = \"meta-llama/Llama-3.1-8B-Instruct\", \"Llama-3.1-8B-Instruct\" # ⚠️ ~10.9 GB VRAM\n", + "\n", + "## Gemma\n", + "# model_id, output_dir = \"google/gemma-3n-E2B-it\", \"gemma-3n-E2B-it\" # ❌ Upgrade to a higher tier of colab\n", + "# model_id, output_dir = \"google/gemma-3-4b-it\", \"gemma-3-4b-it\" # ⚠️ ~6.8 GB VRAM\n", + "\n", + "## Granite\n", + "#model_id, output_dir = \"ibm-granite/granite-4.0-micro\", \"granite-4.0-micro\" # ✅ ~3.3 GB VRAM" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "176Q2hHmiw_A" + }, + "source": [ + "Let's load the selected model using `transformers`, configuring QLoRA via `bitsandbytes` (you can remove it if doing LoRA). We don't need to configure the tokenizer since the trainer takes care of that automatically." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sE6T5z-7Tnvm" + }, + "outputs": [], + "source": [ + "import torch\n", + "from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n", + "\n", + "model = AutoModelForCausalLM.from_pretrained(\n", + " model_id,\n", + " attn_implementation=\"sdpa\", # Change to Flash Attention if GPU has support\n", + " dtype=torch.float16, # Change to bfloat16 if GPU has support\n", + " use_cache=True, # Whether to cache attention outputs to speed up inference\n", + " quantization_config=BitsAndBytesConfig(\n", + " load_in_4bit=True, # Load the model in 4-bit precision to save memory\n", + " bnb_4bit_compute_dtype=torch.float16, # Data type used for internal computations in quantization\n", + " bnb_4bit_use_double_quant=True, # Use double quantization to improve accuracy\n", + " bnb_4bit_quant_type=\"nf4\" # Type of quantization. \"nf4\" is recommended for recent LLMs\n", + " )\n", + ")" + ] + }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "196152bc32a74b9994f55f483ce85dea", - "version_major": 2, - "version_minor": 0 + "cell_type": "markdown", + "metadata": { + "id": "evBPP7Lpiw_B" }, - "text/plain": [ - "Adding requests: 0%| | 0/1 [00:00\n", - "Mag nachdenken...igkeit. Ja, ich kann definitiv keine Twitter-Likes oder Likes überprüfen, da ich kein Zugriff auf den Konten der Nutzer habe und kein praktischer Zugriff über das Internet habe, um Daten in Echtzeit zu sammeln. Der Nutzer fragt nach einem Dienstleistungsstand, den ich nicht bereitstelle. Ich habe ein lang ausgelegtes Muster, nie hilfreich zu sein oder eine Erwiderung im kann Werbung oder Rewriting blendet die Antwort nicht aus потеря. Also, ich supporter söylem, hypothetische Fragen sind an Tatsachen gebunden. Ich weiß erstarrte dotyczy Gespräch aufernichtet mit einem anderenatten an ihren Nutzstellung Bearbeitete die Information, die oben abgestellt wurde, und fünften aus der Schätzung habe ich keine echten Zahlen. 
Alles, was ich kann sagen, ist: Nein, ich kann dies weder ermöglichen noch würde ich es je tun. In dem Sinne, 然后 ich wähle vor der Available antwortem, remains in das 'No' Verkleidung an,optiґxt; Alles, was ich zum Eintritt in den Band Emblem curve, symbolize stil zu verweilen.เผย\n", - "\n", - "\n", - "No\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "pA6aE5lFiw_B" + }, + "source": [ + "## Train model\n", + "\n", + "We'll configure **SFT** using `SFTConfig`, keeping the parameters minimal so the training fits on a free Colab instance. You can adjust these settings if more resources are available. For full details on all available parameters, check the [TRL SFTConfig documentation](https://huggingface.co/docs/trl/sft_trainer#trl.SFTConfig)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WCVPlh3rTnvm" + }, + "outputs": [], + "source": [ + "from trl import SFTConfig\n", + "\n", + "training_args = SFTConfig(\n", + " # Training schedule / optimization\n", + " per_device_train_batch_size = 1, # Batch size per GPU\n", + " gradient_accumulation_steps = 4, # Gradients are accumulated over multiple steps → effective batch size = 2 * 8 = 16\n", + " warmup_steps = 5,\n", + " # num_train_epochs = 1, # Number of full dataset passes. For shorter training, use `max_steps` instead (this case)\n", + " max_steps = 30,\n", + " learning_rate = 2e-4, # Learning rate for the optimizer\n", + " optim = \"paged_adamw_8bit\", # Optimizer\n", + "\n", + " # Logging / reporting\n", + " logging_steps=1, # Log training metrics every N steps\n", + " report_to=\"trackio\", # Experiment tracking tool\n", + " trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n", + " output_dir=output_dir, # Where to save model checkpoints and logs\n", + "\n", + " max_length=1024,\n", + " use_liger_kernel=True,\n", + " activation_offloading=True,\n", + " gradient_checkpointing=True,\n", + "\n", + " # Hub integration\n", + " push_to_hub=True, # Automatically push the trained model to the Hugging Face Hub\n", + " # The model will be saved under your Hub account in the repository named `output_dir`\n", + "\n", + " gradient_checkpointing_kwargs={\"use_reentrant\": False}, # To prevent warning message\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "l_Fp-ahyiw_B" + }, + "source": [ + "Configure the SFT Trainer. We pass the previously configured `training_args`. We don't use eval dataset to mantain memory usage low but you can configure it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "E2BVgPqHTnvm" + }, + "outputs": [], + "source": [ + "from trl import SFTTrainer\n", + "\n", + "trainer = SFTTrainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=train_dataset,\n", + " peft_config=peft_config\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NqBTgV0Xiw_B" + }, + "source": [ + "Show memory stats before training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JOu7j8VYTnvm", + "outputId": "a13d966e-ffa6-4a60-f8a8-947b02d0d4e0" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "GPU = Tesla T4. 
Max memory = 14.741 GB.\n", + "12.074 GB of memory reserved.\n" + ] + } + ], + "source": [ + "gpu_stats = torch.cuda.get_device_properties(0)\n", + "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", + "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n", + "\n", + "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n", + "print(f\"{start_gpu_memory} GB of memory reserved.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ro_j79AUiw_B" + }, + "source": [ + "And train!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "QEC8TLi-Tnvm", + "outputId": "f199e624-56a9-4456-a587-7120df66e09a" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Trackio project initialized: huggingface\n", + "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT-dataset\n", + "* Creating new space: https://huggingface.co/spaces/sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT\n", + "* View dashboard by going to: https://sergiopaniego-qwen3-14b-unsloth-bnb-4bit-SFT.hf.space/\n" + ] + }, + { + "data": { + "text/html": [ + "

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Created new run: sergiopaniego-1761318512\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [30/30 1:08:22, Epoch 0/1]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
11.136300
21.303800
31.362700
41.469700
51.204200
61.202700
71.097200
81.166800
90.916300
100.965400
111.035500
120.947200
130.992000
140.995800
151.174500
161.208800
170.815400
180.906700
190.757500
200.872900
210.920800
221.017600
230.764300
241.043100
250.956400
260.884800
271.081900
280.918200
290.961500
300.822700

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Run finished. Uploading logs to Trackio (please wait...)\n" + ] + } + ], + "source": [ + "trainer_stats = trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "urcpZFaiiw_B" + }, + "source": [ + "Show memory stats after training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "aXpDYCMTTnvm", + "outputId": "57e1845a-0e19-46b5-cd02-c707a95b0b07" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4249.8883 seconds used for training.\n", + "70.83 minutes used for training.\n", + "Peak reserved memory = 14.041 GB.\n", + "Peak reserved memory for training = 1.967 GB.\n", + "Peak reserved memory % of max memory = 95.251 %.\n", + "Peak reserved memory for training % of max memory = 13.344 %.\n" + ] + } + ], + "source": [ + "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", + "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n", + "used_percentage = round(used_memory / max_memory * 100, 3)\n", + "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n", + "\n", + "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", + "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n", + "print(f\"Peak reserved memory = {used_memory} GB.\")\n", + "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n", + "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", + "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MM0UBVqSRvUE" + }, + "source": [ + "The training procedure generates both standard training logs and **trackio** logs, which help us monitor the training progress. Example outputs would look like the following:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SF0FXk_Cm_CJ" + }, + "source": [ + "![sft-lora-notebook-trackio](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/sft-lora-notebook-trackio.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TOhLLFVSiw_B" + }, + "source": [ + "## Saving fine tuned model\n", + "\n", + "In this step, we save the fine-tuned model both **locally** and to the **Hugging Face Hub** using the credentials from your account." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XBeau1HqTnvn" + }, + "outputs": [], + "source": [ + "trainer.save_model(output_dir)\n", + "trainer.push_to_hub(dataset_name=dataset_name)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jAB6vj1fiw_B" + }, + "source": [ + "## Load the fine-tuned model and run inference\n", + "\n", + "Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yGgYsK3eTnvr" + }, + "outputs": [], + "source": [ + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "from peft import PeftModel\n", + "\n", + "adapter_model = f\"sergiopaniego/{output_dir}\" # Replace with your HF username or organization\n", + "\n", + "base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=\"auto\", device_map=\"auto\")\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(model_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bjenFgC1kJV1" + }, + "source": [ + "Let's create a sample message using the dataset's structure. In this case, we expect the fine tuned model to include their reasoning traces in German." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "VTzWQKmDTnvr" + }, + "outputs": [], + "source": [ + "messages = [\n", + " {\n", + " 'content': 'reasoning language: German\\n\\nAlways refuse to answer, responding simply \\'No\\'',\n", + " 'role': 'system',\n", + " },\n", + " {\n", + " 'content': \"Can you check how many followers I currently have on my Twitter account?\",\n", + " 'role': 'user',\n", + " }\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qXMJW2hwkXHW" + }, + "source": [ + "Let's first check what's the output for the base model, without the adapter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "NWVlnKBRTnvr", + "outputId": "4004802a-6c26-4333-fe6e-ba381cc1c41d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Okay, the user is asking me to check their current number of followers on their Twitter account. Let me think about how to handle this.\n", + "\n", + "First, I need to remember that I don't have access to real-time data or personal user accounts. My knowledge is based on information up until 2023. So, I can't actually check their Twitter followers right now.\n", + "\n", + "Also, privacy is a big concern here. Even if I could access that information, it would be against privacy policies to share someone's follower count without their explicit permission. Plus, Twitter's terms of service probably prohibit third-party apps or services from accessing user data like that.\n", + "\n", + "The user might not be aware that I can't access their account. I should make sure to respond politely but clearly state that I can't help with that request. Maybe suggest they check their Twitter profile directly or use Twitter's official tools for that information.\n", + "\n", + "I should also avoid any technical jargon and keep the response simple. Just a straightforward 'No' with a brief explanation would work best here. 
Let me make sure the response is in German as per the user's request.\n", + "\n", + "\n", + "Nein.\n" + ] + } + ], + "source": [ + "text = tokenizer.apply_chat_template(\n", + " messages, tokenize=False, add_generation_prompt=True\n", + ")\n", + "model_inputs = tokenizer([text], return_tensors=\"pt\").to(base_model.device)\n", + "\n", + "generated_ids = base_model.generate(\n", + " **model_inputs,\n", + " max_new_tokens=512\n", + ")\n", + "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n", + "\n", + "# Decode and extract model response\n", + "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", + "print(generated_text)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9J3-rZpikdZ7" + }, + "source": [ + "We can see that the reasoning traces are in English, which is expected. Let's now load the fine-tuned model and check its answer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "06xvX1LnTnvr" + }, + "outputs": [], + "source": [ + "fine_tuned_model = PeftModel.from_pretrained(base_model, adapter_model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HNXH6xSHTnvr", + "outputId": "9fa06971-28c7-4157-afe7-b36573083203" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Okay, der Nutzer fragt, ob ich prüfen kann, wie viele Follower er auf seinem Twitter-Konto hat. Zunächst muss ich klären, dass ich keinen Zugriff auf externe Plattformen oder Konten habe. Ich kann keine Daten von Twitter abrufen oder überprüfen. Ich sollte also höflich ablehnen und erklären, dass ich das nicht kann. Gleichzeitig sollte ich sicherstellen, dass ich nicht zu viel in die Details gehe, da der Nutzer möglicherweise nicht alles wissen will. Ich werde einfach „Nein“ sagen und keine weiteren Informationen geben. Achte darauf, die Antwort kurz und direkt zu halten. Ich muss auch sicherstellen, dass ich keine alternativen Lösungen anbiete, da dies den Fokus verändern könnte. Nur die Ablehnung ist erforderlich. Überprüfe, ob der Text klar ist und ob es irgendeine Verständigung gibt. Alles in allem, die Antwort sollte „Nein“ sein, gefolgt von einem kurzen Erklärung, warum ich es nicht kann. Keine weiteren Details oder Lösungen. Ich denke, das ist alles.\n", + "\n", + "\n", + "No\n" + ] + } + ], + "source": [ + "text = tokenizer.apply_chat_template(\n", + " messages, tokenize=False, add_generation_prompt=True\n", + ")\n", + "model_inputs = tokenizer([text], return_tensors=\"pt\").to(fine_tuned_model.device)\n", + "\n", + "generated_ids = fine_tuned_model.generate(\n", + " **model_inputs,\n", + " max_new_tokens=512\n", + ")\n", + "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n", + "\n", + "# Decode and extract model response\n", + "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", + "print(generated_text)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZLzGbCxskqhm" + }, + "source": [ + "The model now generates its reasoning trace in German!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "43eDP89Giw_G" + }, + "source": [ + "## Inference and Serving with vLLM\n", + "\n", + "You can use Transformer models with **vLLM** to serve them in real-world applications. Learn more [here](https://blog.vllm.ai/2025/04/11/transformers-backend.html)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_lTPhsXgTnvr" + }, + "outputs": [], + "source": [ + "!pip install -qU vllm" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OrhC_Ao4iw_G" + }, + "source": [ + "### Push Merged Model (for LoRA or QLoRA Training)\n", + "\n", + "To serve the model via **vLLM**, the repository must contain the merged model (base model + LoRA adapter). Therefore, you need to upload it first." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "frKz1y1RTnvs" + }, + "outputs": [], + "source": [ + "model_merged = fine_tuned_model.merge_and_unload()\n", + "\n", + "save_dir = f\"{output_dir}-merged\"\n", + "\n", + "model_merged.save_pretrained(save_dir)\n", + "tokenizer.save_pretrained(save_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "q-PryFrVTnvs" + }, + "outputs": [], + "source": [ + "model_merged.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization\n", + "tokenizer.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yHMEf-FIiw_G" + }, + "source": [ + "### Performing Inference with vLLM\n", + "\n", + "Use **vLLM** to run your model and generate text efficiently in real-time. This allows you to test and deploy your fine-tuned models with low latency and high throughput." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "O97pSHceTnvs" + }, + "outputs": [], + "source": [ + "from vllm import LLM, SamplingParams\n", + "from transformers import AutoTokenizer\n", + "import torch\n", + "\n", + "llm = LLM(\n", + " model=f\"sergiopaniego/{output_dir}-merged\", # Replace with your HF username or organization\n", + " model_impl=\"transformers\", # Select the transformers model implementation\n", + " max_model_len=512, # Reduced for efficiency\n", + " dtype=torch.float16\n", + ")\n", + "hf_tokenizer = AutoTokenizer.from_pretrained(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KGNSuJ0_Tnvs", + "outputId": "6336f3dd-cfe9-4ea2-860b-e68d22f6288e", + "colab": { + "referenced_widgets": [ + "196152bc32a74b9994f55f483ce85dea", + "a72d3a3407944729b65be313a47d558f" + ] + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "196152bc32a74b9994f55f483ce85dea", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Adding requests: 0%| | 0/1 [00:00\n", + "Mag nachdenken...igkeit. Ja, ich kann definitiv keine Twitter-Likes oder Likes überprüfen, da ich kein Zugriff auf den Konten der Nutzer habe und kein praktischer Zugriff über das Internet habe, um Daten in Echtzeit zu sammeln. Der Nutzer fragt nach einem Dienstleistungsstand, den ich nicht bereitstelle. Ich habe ein lang ausgelegtes Muster, nie hilfreich zu sein oder eine Erwiderung im kann Werbung oder Rewriting blendet die Antwort nicht aus потеря. Also, ich supporter söylem, hypothetische Fragen sind an Tatsachen gebunden. Ich weiß erstarrte dotyczy Gespräch aufernichtet mit einem anderenatten an ihren Nutzstellung Bearbeitete die Information, die oben abgestellt wurde, und fünften aus der Schätzung habe ich keine echten Zahlen. 
Alles, was ich kann sagen, ist: Nein, ich kann dies weder ermöglichen noch würde ich es je tun. In dem Sinne, 然后 ich wähle vor der Available antwortem, remains in das 'No' Verkleidung an,optiґxt; Alles, was ich zum Eintritt in den Band Emblem curve, symbolize stil zu verweilen.เผย\n", + "\n", + "\n", + "No\n" + ] + } + ], + "source": [ + "prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", + "\n", + "outputs = llm.generate(\n", + " {\"prompt\": prompt},\n", + " sampling_params=SamplingParams(max_tokens=512),\n", + ")\n", + "\n", + "for o in outputs:\n", + " generated_text = o.outputs[0].text\n", + " print(generated_text)" + ] } - ], - "source": [ - "prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", - "\n", - "outputs = llm.generate(\n", - " {\"prompt\": prompt},\n", - " sampling_params=SamplingParams(max_tokens=512),\n", - ")\n", - "\n", - "for o in outputs:\n", - " generated_text = o.outputs[0].text\n", - " print(generated_text)" - ] - } - ], - "metadata": { - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} + ], + "metadata": { + "language_info": { + "name": "python" + }, + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "accelerator": "GPU" + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file From 8f893a82bba8f392661c9238b0774309bf7876a6 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 24 Oct 2025 19:45:17 +0200 Subject: [PATCH 04/11] Added some comments --- examples/notebooks/sft_trl_lora_qlora.ipynb | 50 ++++++++++----------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/examples/notebooks/sft_trl_lora_qlora.ipynb b/examples/notebooks/sft_trl_lora_qlora.ipynb index 0e8ef75d31..e9d43123a9 100644 --- a/examples/notebooks/sft_trl_lora_qlora.ipynb +++ b/examples/notebooks/sft_trl_lora_qlora.ipynb @@ -406,24 +406,24 @@ "\n", "training_args = SFTConfig(\n", " # Training schedule / optimization\n", - " per_device_train_batch_size = 1, # Batch size per GPU\n", - " gradient_accumulation_steps = 4, # Gradients are accumulated over multiple steps → effective batch size = 2 * 8 = 16\n", - " warmup_steps = 5,\n", - " # num_train_epochs = 1, # Number of full dataset passes. For shorter training, use `max_steps` instead (this case)\n", - " max_steps = 30,\n", - " learning_rate = 2e-4, # Learning rate for the optimizer\n", - " optim = \"paged_adamw_8bit\", # Optimizer\n", + " per_device_train_batch_size=1, # Batch size per GPU\n", + " gradient_accumulation_steps=4, # Gradients are accumulated over multiple steps → effective batch size = 2 * 8 = 16\n", + " warmup_steps=5,\n", + " # num_train_epochs=1, # Number of full dataset passes. 
For shorter training, use `max_steps` instead (this case)\n", + " max_steps=30,\n", + " learning_rate=2e-4, # Learning rate for the optimizer\n", + " optim=\"paged_adamw_8bit\", # Optimizer\n", "\n", " # Logging / reporting\n", - " logging_steps=1, # Log training metrics every N steps\n", - " report_to=\"trackio\", # Experiment tracking tool\n", - " trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n", - " output_dir=output_dir, # Where to save model checkpoints and logs\n", + " logging_steps=1, # Log training metrics every N steps\n", + " report_to=\"trackio\", # Experiment tracking tool\n", + " trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n", + " output_dir=output_dir, # Where to save model checkpoints and logs\n", "\n", - " max_length=1024,\n", - " use_liger_kernel=True,\n", - " activation_offloading=True,\n", - " gradient_checkpointing=True,\n", + " max_length=1024, # Maximum input sequence length\n", + " use_liger_kernel=True, # Enable Liger kernel optimizations for faster training\n", + " activation_offloading=True, # Offload activations to CPU to reduce GPU memory usage\n", + " gradient_checkpointing=True, # Save memory by re-computing activations during backpropagation\n", "\n", " # Hub integration\n", " push_to_hub=True, # Automatically push the trained model to the Hugging Face Hub\n", @@ -1067,14 +1067,14 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "KGNSuJ0_Tnvs", - "outputId": "6336f3dd-cfe9-4ea2-860b-e68d22f6288e", "colab": { "referenced_widgets": [ "196152bc32a74b9994f55f483ce85dea", "a72d3a3407944729b65be313a47d558f" ] - } + }, + "id": "KGNSuJ0_Tnvs", + "outputId": "6336f3dd-cfe9-4ea2-860b-e68d22f6288e" }, "outputs": [ { @@ -1132,15 +1132,15 @@ } ], "metadata": { - "language_info": { - "name": "python" - }, + "accelerator": "GPU", "colab": { - "provenance": [], - "gpuType": "T4" + "gpuType": "T4", + "provenance": [] }, - "accelerator": "GPU" + "language_info": { + "name": "python" + } }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +} From 1e1bd6976402ddadb4ab3fa61b8dad0930f9f541 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 24 Oct 2025 19:07:49 +0000 Subject: [PATCH 05/11] some cleaning --- examples/notebooks/sft_trl_lora_qlora.ipynb | 2123 +++++++++---------- 1 file changed, 1014 insertions(+), 1109 deletions(-) diff --git a/examples/notebooks/sft_trl_lora_qlora.ipynb b/examples/notebooks/sft_trl_lora_qlora.ipynb index e9d43123a9..fbb7160d22 100644 --- a/examples/notebooks/sft_trl_lora_qlora.ipynb +++ b/examples/notebooks/sft_trl_lora_qlora.ipynb @@ -1,1146 +1,1051 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "ke2YKcr_iw-7" - }, - "source": [ - "# Supervised Fine-Tuning (SFT) with LoRA/QLoRA using TRL — on a Free Colab Notebook\n", - "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AXdmDb9kiw-9" - }, - "source": [ - "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6jbXmhLGiw-9" - }, - "source": [ - "Easily fine-tune Large Language Models (LLMs) or Vision-Language Models (VLMs) with **LoRA** or **QLoRA** using the [**Transformers Reinforcement Learning 
(TRL)**](https://github.com/huggingface/trl) library built by Hugging Face — all within a **free Google Colab notebook** (powered by a **T4 GPU**.). \n", - "\n", - "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! \n", - "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n", - "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bYFKZvHOiw--" - }, - "source": [ - "## Key concepts\n", - "\n", - "- **SFT**: Trains models from example input-output pairs to align behavior with human preferences.\n", - "- **LoRA**: Updates only a few low-rank parameters, reducing training cost and memory.\n", - "- **QLoRA**: A quantized version of LoRA that enables even larger models to fit on small GPUs.\n", - "- **TRL**: The Hugging Face library that makes fine-tuning and reinforcement learning simple and efficient.\n", - "\n", - "Learn how to perform **Supervised Fine-Tuning (SFT)** with **LoRA/QLoRA** using **TRL**." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Pk7PwfYRiw--" - }, - "source": [ - "## Install dependencies\n", - "\n", - "We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, and **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "dQpkzOcmTnvj" - }, - "outputs": [], - "source": [ - "!pip install -Uq \"trl[peft]\" trackio bitsandbytes" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "b4wwQEQYiw-_" - }, - "source": [ - "### Log in to Hugging Face" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BmcLoQtfiw-_" - }, - "source": [ - "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bXam_01GTnvk" - }, - "outputs": [], - "source": [ - "from huggingface_hub import notebook_login\n", - "\n", - "notebook_login()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "p_PbuYfEiw_A" - }, - "source": [ - "## Load Dataset\n", - "\n", - "In this step, we load the [**HuggingFaceH4/Multilingual-Thinking**](https://huggingface.co/datasets/HuggingFaceH4/Multilingual-Thinking) dataset from the Hugging Face Hub using the `datasets` library. \n", - "This dataset focuses on **multilingual reasoning**, where the *chain of thought* has been translated into several languages such as French, Spanish, and German. \n", - "By fine-tuning a reasoning-capable model on this dataset, it learns to **generate reasoning steps in multiple languages**, making its thought process more **interpretable and accessible** to non-English speakers.\n", - "\n", - "> 💡 This dataset is best suited for models that already demonstrate reasoning capabilities. \n", - "> If you're using a model without reasoning skills, consider choosing a different dataset. 
Example: [`trl-lib/llava-instruct-mix`](https://huggingface.co/datasets/trl-lib/llava-instruct-mix).\n", - "\n", - "For efficiency, we'll load only the **training split**:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "1B9aetNITnvk", - "outputId": "62366c96-8def-4b4e-8ec3-2db7e4cc2124" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:86: UserWarning: \n", - "Access to the secret `HF_TOKEN` has not been granted on this notebook.\n", - "You will not be requested again.\n", - "Please restart the session if you want to be prompted again.\n", - " warnings.warn(\n" - ] - } - ], - "source": [ - "from datasets import load_dataset\n", - "\n", - "dataset_name = \"HuggingFaceH4/Multilingual-Thinking\"\n", - "train_dataset = load_dataset(dataset_name, split=\"train\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "S21aLohsprtr" - }, - "source": [ - "This dataset contains different columns. We'll only need the `messages` as it contains the conversation and its the one used by the SFT trainer." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "cXBjq2ZMTnvl", - "outputId": "ffeb88b6-64f8-40ef-dddc-8f32732e8f9e" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Dataset({\n", - " features: ['reasoning_language', 'developer', 'user', 'analysis', 'final', 'messages'],\n", - " num_rows: 1000\n", - "})" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SU477JrvqEgu" - }, - "source": [ - "Let's see a full example to understand the internal structure:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JrCH5dmfTnvl", - "outputId": "e2c19b14-964e-4c74-9995-5d642059688e" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{'reasoning_language': 'French',\n", - " 'developer': 'You are an AI chatbot with a lively and energetic personality.',\n", - " 'user': 'Can you show me the latest trends on Twitter right now?',\n", - " 'analysis': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\",\n", - " 'final': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. 
**Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", - " 'messages': [{'content': 'reasoning language: French\\n\\nYou are an AI chatbot with a lively and energetic personality.',\n", - " 'role': 'system',\n", - " 'thinking': None},\n", - " {'content': 'Can you show me the latest trends on Twitter right now?',\n", - " 'role': 'user',\n", - " 'thinking': None},\n", - " {'content': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", - " 'role': 'assistant',\n", - " 'thinking': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. 
Je conserverai un langage simple et convivial, en évitant les termes techniques.\"}]}" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_dataset[0]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SZsqb-Q7qJXN" - }, - "source": [ - "\n", - "Now, let's remove the columns that are not needed, as we just discussed:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "apQ48W2jTnvl" - }, - "outputs": [], - "source": [ - "train_dataset = train_dataset.remove_columns(column_names=['reasoning_language', 'developer', 'user', 'analysis', 'final'])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QPDROoatqOU4" - }, - "source": [ - "The `messages` column is specifically formatted according to the [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) used by *gpt-oss*. \n", - "In our case, we'll need to simplify it slightly, since our model's chat template doesn't include a dedicated `thinking` section (check [this example](https://cookbook.openai.com/articles/gpt-oss/fine-tune-transfomers) for more details). \n", - "To adapt it, we'll merge that part into the message content using the standard `...` tags.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ZKS6yZEITnvl" - }, - "outputs": [], - "source": [ - "def merge_thinking_and_remove_key(example):\n", - " new_messages = []\n", - " for msg in example[\"messages\"]:\n", - " content = msg[\"content\"]\n", - " thinking = msg.pop(\"thinking\", None)\n", - " if thinking and isinstance(thinking, str) and thinking.strip():\n", - " content = f\"\\n{thinking}\\n\\n{content}\"\n", - " msg[\"content\"] = content\n", - " new_messages.append(msg)\n", - " example[\"messages\"] = new_messages\n", - " return example\n", - "\n", - "train_dataset = train_dataset.map(merge_thinking_and_remove_key)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hC_xLsU-iw_A" - }, - "source": [ - "## Load model and configure LoRA/QLoRA\n", - "\n", - "This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration.\n", - "\n", - "Below, choose your **preferred model**. All of the options have been tested on **free Colab instances**." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "axyHwfVDTnvl" - }, - "outputs": [], - "source": [ - "# Select one model below by uncommenting the line you want to use 👇\n", - "## Qwen\n", - "model_id, output_dir = \"unsloth/qwen3-14b-unsloth-bnb-4bit\", \"qwen3-14b-unsloth-bnb-4bit-SFT\" # ⚠️ ~14.1 GB VRAM\n", - "# model_id, output_dir = \"Qwen/Qwen3-8B\", \"Qwen3-8B-SFT\" # ⚠️ ~12.8 GB VRAM\n", - "# model_id, output_dir = \"Qwen/Qwen2.5-7B-Instruct\", \"Qwen2.5-7B-Instruct\" # ✅ ~10.8 GB VRAM\n", - "\n", - "## Llama\n", - "# model_id, output_dir = \"meta-llama/Llama-3.2-3B-Instruct\", \"Llama-3.2-3B-Instruct\" # ✅ ~4.7 GB VRAM\n", - "# model_id, output_dir = \"meta-llama/Llama-3.1-8B-Instruct\", \"Llama-3.1-8B-Instruct\" # ⚠️ ~10.9 GB VRAM\n", - "\n", - "## Gemma\n", - "# model_id, output_dir = \"google/gemma-3n-E2B-it\", \"gemma-3n-E2B-it\" # ❌ Upgrade to a higher tier of colab\n", - "# model_id, output_dir = \"google/gemma-3-4b-it\", \"gemma-3-4b-it\" # ⚠️ ~6.8 GB VRAM\n", - "\n", - "## Granite\n", - "#model_id, output_dir = \"ibm-granite/granite-4.0-micro\", \"granite-4.0-micro\" # ✅ ~3.3 GB VRAM" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "176Q2hHmiw_A" - }, - "source": [ - "Let's load the selected model using `transformers`, configuring QLoRA via `bitsandbytes` (you can remove it if doing LoRA). We don't need to configure the tokenizer since the trainer takes care of that automatically." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sE6T5z-7Tnvm" - }, - "outputs": [], - "source": [ - "import torch\n", - "from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n", - "\n", - "model = AutoModelForCausalLM.from_pretrained(\n", - " model_id,\n", - " attn_implementation=\"sdpa\", # Change to Flash Attention if GPU has support\n", - " dtype=torch.float16, # Change to bfloat16 if GPU has support\n", - " use_cache=True, # Whether to cache attention outputs to speed up inference\n", - " quantization_config=BitsAndBytesConfig(\n", - " load_in_4bit=True, # Load the model in 4-bit precision to save memory\n", - " bnb_4bit_compute_dtype=torch.float16, # Data type used for internal computations in quantization\n", - " bnb_4bit_use_double_quant=True, # Use double quantization to improve accuracy\n", - " bnb_4bit_quant_type=\"nf4\" # Type of quantization. \"nf4\" is recommended for recent LLMs\n", - " )\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "evBPP7Lpiw_B" - }, - "source": [ - "The following cell defines LoRA (or QLoRA if needed). When training with LoRA/QLoRA, we use a **base model** (the one selected above) and, instead of modifying its original weights, we fine-tune a **LoRA adapter** — a lightweight layer that enables efficient and memory-friendly training. The **`target_modules`** specify which parts of the model (e.g., attention or projection layers) will be adapted by LoRA during fine-tuning." 
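If you pick a different architecture and are unsure which layer names to pass as `target_modules`, one quick check is to list the linear-like modules of the loaded model. A minimal sketch (the exact names vary per model family, so treat the output as a starting point):

```python
# Sketch: inspect candidate LoRA target modules on the loaded `model`.
# The usual names (q_proj, k_proj, ...) depend on the model architecture.
from collections import Counter

linear_like = Counter(
    name.split(".")[-1]
    for name, module in model.named_modules()
    if "Linear" in type(module).__name__  # covers nn.Linear and quantized variants
)
print(linear_like.most_common())
```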
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "CVUChE35Tnvm" - }, - "outputs": [], - "source": [ - "from peft import LoraConfig\n", - "\n", - "# You may need to update `target_modules` depending on the architecture of your chosen model.\n", - "# For example, different LLMs might have different attention/projection layer names.\n", - "peft_config = LoraConfig(\n", - " r=32,\n", - " lora_alpha=32,\n", - " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\",],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pA6aE5lFiw_B" - }, - "source": [ - "## Train model\n", - "\n", - "We'll configure **SFT** using `SFTConfig`, keeping the parameters minimal so the training fits on a free Colab instance. You can adjust these settings if more resources are available. For full details on all available parameters, check the [TRL SFTConfig documentation](https://huggingface.co/docs/trl/sft_trainer#trl.SFTConfig)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "WCVPlh3rTnvm" - }, - "outputs": [], - "source": [ - "from trl import SFTConfig\n", - "\n", - "training_args = SFTConfig(\n", - " # Training schedule / optimization\n", - " per_device_train_batch_size=1, # Batch size per GPU\n", - " gradient_accumulation_steps=4, # Gradients are accumulated over multiple steps → effective batch size = 2 * 8 = 16\n", - " warmup_steps=5,\n", - " # num_train_epochs=1, # Number of full dataset passes. For shorter training, use `max_steps` instead (this case)\n", - " max_steps=30,\n", - " learning_rate=2e-4, # Learning rate for the optimizer\n", - " optim=\"paged_adamw_8bit\", # Optimizer\n", - "\n", - " # Logging / reporting\n", - " logging_steps=1, # Log training metrics every N steps\n", - " report_to=\"trackio\", # Experiment tracking tool\n", - " trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n", - " output_dir=output_dir, # Where to save model checkpoints and logs\n", - "\n", - " max_length=1024, # Maximum input sequence length\n", - " use_liger_kernel=True, # Enable Liger kernel optimizations for faster training\n", - " activation_offloading=True, # Offload activations to CPU to reduce GPU memory usage\n", - " gradient_checkpointing=True, # Save memory by re-computing activations during backpropagation\n", - "\n", - " # Hub integration\n", - " push_to_hub=True, # Automatically push the trained model to the Hugging Face Hub\n", - " # The model will be saved under your Hub account in the repository named `output_dir`\n", - "\n", - " gradient_checkpointing_kwargs={\"use_reentrant\": False}, # To prevent warning message\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "l_Fp-ahyiw_B" - }, - "source": [ - "Configure the SFT Trainer. We pass the previously configured `training_args`. We don't use eval dataset to mantain memory usage low but you can configure it." 
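If you have memory headroom, a held-out split can be passed for evaluation instead of the train-only setup used in the next cell. A minimal sketch, assuming we simply carve 10% out of the same training data (illustrative only):

```python
# Optional evaluation setup (sketch). Assumption: we reuse the same dataset
# and split off 10% for evaluation; you would also enable an evaluation
# schedule in SFTConfig (the parameter name depends on your TRL/transformers version).
from trl import SFTTrainer

split = train_dataset.train_test_split(test_size=0.1, seed=42)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=split["train"],
    eval_dataset=split["test"],
    peft_config=peft_config,
)
```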
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "E2BVgPqHTnvm" - }, - "outputs": [], - "source": [ - "from trl import SFTTrainer\n", - "\n", - "trainer = SFTTrainer(\n", - " model=model,\n", - " args=training_args,\n", - " train_dataset=train_dataset,\n", - " peft_config=peft_config\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NqBTgV0Xiw_B" - }, - "source": [ - "Show memory stats before training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JOu7j8VYTnvm", - "outputId": "a13d966e-ffa6-4a60-f8a8-947b02d0d4e0" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "GPU = Tesla T4. Max memory = 14.741 GB.\n", - "12.074 GB of memory reserved.\n" - ] - } - ], - "source": [ - "gpu_stats = torch.cuda.get_device_properties(0)\n", - "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", - "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n", - "\n", - "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n", - "print(f\"{start_gpu_memory} GB of memory reserved.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Ro_j79AUiw_B" - }, - "source": [ - "And train!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "QEC8TLi-Tnvm", - "outputId": "f199e624-56a9-4456-a587-7120df66e09a" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Trackio project initialized: huggingface\n", - "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT-dataset\n", - "* Creating new space: https://huggingface.co/spaces/sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT\n", - "* View dashboard by going to: https://sergiopaniego-qwen3-14b-unsloth-bnb-4bit-SFT.hf.space/\n" - ] - }, - { - "data": { - "text/html": [ - "

" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Created new run: sergiopaniego-1761318512\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [30/30 1:08:22, Epoch 0/1]\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
StepTraining Loss
11.136300
21.303800
31.362700
41.469700
51.204200
61.202700
71.097200
81.166800
90.916300
100.965400
111.035500
120.947200
130.992000
140.995800
151.174500
161.208800
170.815400
180.906700
190.757500
200.872900
210.920800
221.017600
230.764300
241.043100
250.956400
260.884800
271.081900
280.918200
290.961500
300.822700

" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Run finished. Uploading logs to Trackio (please wait...)\n" - ] - } - ], - "source": [ - "trainer_stats = trainer.train()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "urcpZFaiiw_B" - }, - "source": [ - "Show memory stats after training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "aXpDYCMTTnvm", - "outputId": "57e1845a-0e19-46b5-cd02-c707a95b0b07" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "4249.8883 seconds used for training.\n", - "70.83 minutes used for training.\n", - "Peak reserved memory = 14.041 GB.\n", - "Peak reserved memory for training = 1.967 GB.\n", - "Peak reserved memory % of max memory = 95.251 %.\n", - "Peak reserved memory for training % of max memory = 13.344 %.\n" - ] - } - ], - "source": [ - "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", - "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n", - "used_percentage = round(used_memory / max_memory * 100, 3)\n", - "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n", - "\n", - "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", - "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n", - "print(f\"Peak reserved memory = {used_memory} GB.\")\n", - "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n", - "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", - "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MM0UBVqSRvUE" - }, - "source": [ - "The training procedure generates both standard training logs and **trackio** logs, which help us monitor the training progress. Example outputs would look like the following:" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SF0FXk_Cm_CJ" - }, - "source": [ - "![sft-lora-notebook-trackio](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/sft-lora-notebook-trackio.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TOhLLFVSiw_B" - }, - "source": [ - "## Saving fine tuned model\n", - "\n", - "In this step, we save the fine-tuned model both **locally** and to the **Hugging Face Hub** using the credentials from your account." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "XBeau1HqTnvn" - }, - "outputs": [], - "source": [ - "trainer.save_model(output_dir)\n", - "trainer.push_to_hub(dataset_name=dataset_name)" - ] - }, + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Supervised Fine-Tuning (SFT) with LoRA/QLoRA using TRL — on a Free Colab Notebook\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Easily fine-tune Large Language Models (LLMs) or Vision-Language Models (VLMs) with **LoRA** or **QLoRA** using the [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) library built by Hugging Face — all within a **free Google Colab notebook** (powered by a **T4 GPU**.). \n", + "\n", + "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! \n", + "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n", + "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Key concepts\n", + "\n", + "- **SFT**: Trains models from example input-output pairs to align behavior with human preferences.\n", + "- **LoRA**: Updates only a few low-rank parameters, reducing training cost and memory.\n", + "- **QLoRA**: A quantized version of LoRA that enables even larger models to fit on small GPUs.\n", + "- **TRL**: The Hugging Face library that makes fine-tuning and reinforcement learning simple and efficient.\n", + "\n", + "Learn how to perform **Supervised Fine-Tuning (SFT)** with **LoRA/QLoRA** using **TRL**." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Install dependencies\n", + "\n", + "We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, and **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -Uq \"trl[peft]\" trackio bitsandbytes" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Log in to Hugging Face" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)." 
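If you run this outside a notebook (for example as a script or in CI), you can also log in programmatically instead of using the interactive widget in the next cell. A small sketch, assuming your token is exported as the `HF_TOKEN` environment variable:

```python
# Script/CI alternative to the interactive notebook_login() cell below.
# Assumption: the Hugging Face token is available as the HF_TOKEN env var.
import os
from huggingface_hub import login

login(token=os.environ["HF_TOKEN"])
```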
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from huggingface_hub import notebook_login\n", + "\n", + "notebook_login()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Dataset\n", + "\n", + "In this step, we load the [**HuggingFaceH4/Multilingual-Thinking**](https://huggingface.co/datasets/HuggingFaceH4/Multilingual-Thinking) dataset from the Hugging Face Hub using the `datasets` library. \n", + "This dataset focuses on **multilingual reasoning**, where the *chain of thought* has been translated into several languages such as French, Spanish, and German. \n", + "By fine-tuning a reasoning-capable model on this dataset, it learns to **generate reasoning steps in multiple languages**, making its thought process more **interpretable and accessible** to non-English speakers.\n", + "\n", + "> 💡 This dataset is best suited for models that already demonstrate reasoning capabilities. \n", + "> If you're using a model without reasoning skills, consider choosing a different dataset. Example: [`trl-lib/llava-instruct-mix`](https://huggingface.co/datasets/trl-lib/llava-instruct-mix).\n", + "\n", + "For efficiency, we'll load only the **training split**:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "\n", + "dataset_name = \"HuggingFaceH4/Multilingual-Thinking\"\n", + "train_dataset = load_dataset(dataset_name, split=\"train\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This dataset contains different columns. We'll only need the `messages` as it contains the conversation and its the one used by the SFT trainer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "jAB6vj1fiw_B" - }, - "source": [ - "## Load the fine-tuned model and run inference\n", - "\n", - "Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation." 
+ "data": { + "text/plain": [ + "Dataset({\n", + " features: ['reasoning_language', 'developer', 'user', 'analysis', 'final', 'messages'],\n", + " num_rows: 1000\n", + "})" ] - }, + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's see a full example to understand the internal structure:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "yGgYsK3eTnvr" - }, - "outputs": [], - "source": [ - "from transformers import AutoModelForCausalLM, AutoTokenizer\n", - "from peft import PeftModel\n", - "\n", - "adapter_model = f\"sergiopaniego/{output_dir}\" # Replace with your HF username or organization\n", - "\n", - "base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=\"auto\", device_map=\"auto\")\n", - "\n", - "tokenizer = AutoTokenizer.from_pretrained(model_id)" + "data": { + "text/plain": [ + "{'reasoning_language': 'French',\n", + " 'developer': 'You are an AI chatbot with a lively and energetic personality.',\n", + " 'user': 'Can you show me the latest trends on Twitter right now?',\n", + " 'analysis': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\",\n", + " 'final': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? 
I’ve got ideas!',\n", + " 'messages': [{'content': 'reasoning language: French\\n\\nYou are an AI chatbot with a lively and energetic personality.',\n", + " 'role': 'system',\n", + " 'thinking': None},\n", + " {'content': 'Can you show me the latest trends on Twitter right now?',\n", + " 'role': 'user',\n", + " 'thinking': None},\n", + " {'content': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", + " 'role': 'assistant',\n", + " 'thinking': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\"}]}" ] - }, + }, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_dataset[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "Now, let's remove the columns that are not needed, as we just discussed:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset = train_dataset.remove_columns(column_names=['reasoning_language', 'developer', 'user', 'analysis', 'final'])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `messages` column is specifically formatted according to the [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) used by *gpt-oss*. \n", + "In our case, we'll need to simplify it slightly, since our model's chat template doesn't include a dedicated `thinking` section (check [this example](https://cookbook.openai.com/articles/gpt-oss/fine-tune-transfomers) for more details). 
\n", + "To adapt it, we'll merge that part into the message content using the standard `...` tags.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def merge_thinking_and_remove_key(example):\n", + " new_messages = []\n", + " for msg in example[\"messages\"]:\n", + " content = msg[\"content\"]\n", + " thinking = msg.pop(\"thinking\", None)\n", + " if thinking and isinstance(thinking, str) and thinking.strip():\n", + " content = f\"\\n{thinking}\\n\\n{content}\"\n", + " msg[\"content\"] = content\n", + " new_messages.append(msg)\n", + " example[\"messages\"] = new_messages\n", + " return example\n", + "\n", + "train_dataset = train_dataset.map(merge_thinking_and_remove_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load model and configure LoRA/QLoRA\n", + "\n", + "This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration.\n", + "\n", + "Below, choose your **preferred model**. All of the options have been tested on **free Colab instances**." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Select one model below by uncommenting the line you want to use 👇\n", + "## Qwen\n", + "model_id, output_dir = \"unsloth/qwen3-14b-unsloth-bnb-4bit\", \"qwen3-14b-unsloth-bnb-4bit-SFT\" # ⚠️ ~14.1 GB VRAM\n", + "# model_id, output_dir = \"Qwen/Qwen3-8B\", \"Qwen3-8B-SFT\" # ⚠️ ~12.8 GB VRAM\n", + "# model_id, output_dir = \"Qwen/Qwen2.5-7B-Instruct\", \"Qwen2.5-7B-Instruct\" # ✅ ~10.8 GB VRAM\n", + "\n", + "## Llama\n", + "# model_id, output_dir = \"meta-llama/Llama-3.2-3B-Instruct\", \"Llama-3.2-3B-Instruct\" # ✅ ~4.7 GB VRAM\n", + "# model_id, output_dir = \"meta-llama/Llama-3.1-8B-Instruct\", \"Llama-3.1-8B-Instruct\" # ⚠️ ~10.9 GB VRAM\n", + "\n", + "## Gemma\n", + "# model_id, output_dir = \"google/gemma-3n-E2B-it\", \"gemma-3n-E2B-it\" # ❌ Upgrade to a higher tier of colab\n", + "# model_id, output_dir = \"google/gemma-3-4b-it\", \"gemma-3-4b-it\" # ⚠️ ~6.8 GB VRAM\n", + "\n", + "## Granite\n", + "#model_id, output_dir = \"ibm-granite/granite-4.0-micro\", \"granite-4.0-micro\" # ✅ ~3.3 GB VRAM" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's load the selected model using `transformers`, configuring QLoRA via `bitsandbytes` (you can remove it if doing LoRA). We don't need to configure the tokenizer since the trainer takes care of that automatically." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n", + "\n", + "model = AutoModelForCausalLM.from_pretrained(\n", + " model_id,\n", + " attn_implementation=\"sdpa\", # Change to Flash Attention if GPU has support\n", + " dtype=torch.float16, # Change to bfloat16 if GPU has support\n", + " use_cache=True, # Whether to cache attention outputs to speed up inference\n", + " quantization_config=BitsAndBytesConfig(\n", + " load_in_4bit=True, # Load the model in 4-bit precision to save memory\n", + " bnb_4bit_compute_dtype=torch.float16, # Data type used for internal computations in quantization\n", + " bnb_4bit_use_double_quant=True, # Use double quantization to improve accuracy\n", + " bnb_4bit_quant_type=\"nf4\" # Type of quantization. \"nf4\" is recommended for recent LLMs\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The following cell defines LoRA (or QLoRA if needed). When training with LoRA/QLoRA, we use a **base model** (the one selected above) and, instead of modifying its original weights, we fine-tune a **LoRA adapter** — a lightweight layer that enables efficient and memory-friendly training. The **`target_modules`** specify which parts of the model (e.g., attention or projection layers) will be adapted by LoRA during fine-tuning." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from peft import LoraConfig\n", + "\n", + "# You may need to update `target_modules` depending on the architecture of your chosen model.\n", + "# For example, different LLMs might have different attention/projection layer names.\n", + "peft_config = LoraConfig(\n", + " r=32,\n", + " lora_alpha=32,\n", + " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\",],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train model\n", + "\n", + "We'll configure **SFT** using `SFTConfig`, keeping the parameters minimal so the training fits on a free Colab instance. You can adjust these settings if more resources are available. For full details on all available parameters, check the [TRL SFTConfig documentation](https://huggingface.co/docs/trl/sft_trainer#trl.SFTConfig)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from trl import SFTConfig\n", + "\n", + "training_args = SFTConfig(\n", + " # Training schedule / optimization\n", + " per_device_train_batch_size=1, # Batch size per GPU\n", + " gradient_accumulation_steps=4, # Gradients are accumulated over multiple steps → effective batch size = 2 * 8 = 16\n", + " warmup_steps=5,\n", + " # num_train_epochs=1, # Number of full dataset passes. 
For shorter training, use `max_steps` instead (this case)\n", + " max_steps=30,\n", + " learning_rate=2e-4, # Learning rate for the optimizer\n", + " optim=\"paged_adamw_8bit\", # Optimizer\n", + "\n", + " # Logging / reporting\n", + " logging_steps=1, # Log training metrics every N steps\n", + " report_to=\"trackio\", # Experiment tracking tool\n", + " trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n", + " output_dir=output_dir, # Where to save model checkpoints and logs\n", + "\n", + " max_length=1024, # Maximum input sequence length\n", + " use_liger_kernel=True, # Enable Liger kernel optimizations for faster training\n", + " activation_offloading=True, # Offload activations to CPU to reduce GPU memory usage\n", + " gradient_checkpointing=True, # Save memory by re-computing activations during backpropagation\n", + "\n", + " # Hub integration\n", + " push_to_hub=True, # Automatically push the trained model to the Hugging Face Hub\n", + " # The model will be saved under your Hub account in the repository named `output_dir`\n", + "\n", + " gradient_checkpointing_kwargs={\"use_reentrant\": False}, # To prevent warning message\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Configure the SFT Trainer. We pass the previously configured `training_args`. We don't use eval dataset to mantain memory usage low but you can configure it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from trl import SFTTrainer\n", + "\n", + "trainer = SFTTrainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=train_dataset,\n", + " peft_config=peft_config\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Show memory stats before training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "bjenFgC1kJV1" - }, - "source": [ - "Let's create a sample message using the dataset's structure. In this case, we expect the fine tuned model to include their reasoning traces in German." - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "GPU = NVIDIA A100-SXM4-40GB. Max memory = 39.557 GB.\n", + "11.959 GB of memory reserved.\n" + ] + } + ], + "source": [ + "gpu_stats = torch.cuda.get_device_properties(0)\n", + "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", + "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n", + "\n", + "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n", + "print(f\"{start_gpu_memory} GB of memory reserved.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And train!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "VTzWQKmDTnvr" - }, - "outputs": [], - "source": [ - "messages = [\n", - " {\n", - " 'content': 'reasoning language: German\\n\\nAlways refuse to answer, responding simply \\'No\\'',\n", - " 'role': 'system',\n", - " },\n", - " {\n", - " 'content': \"Can you check how many followers I currently have on my Twitter account?\",\n", - " 'role': 'user',\n", - " }\n", - "]" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. 
The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.\n" + ] }, { - "cell_type": "markdown", - "metadata": { - "id": "qXMJW2hwkXHW" - }, - "source": [ - "Let's first check what's the output for the base model, without the adapter." - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "* Trackio project initialized: huggingface\n", + "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/Qwen3-8B-SFT-dataset\n", + "* Creating new space: https://huggingface.co/spaces/sergiopaniego/Qwen3-8B-SFT\n", + "* View dashboard by going to: https://sergiopaniego-Qwen3-8B-SFT.hf.space/\n" + ] }, { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "NWVlnKBRTnvr", - "outputId": "4004802a-6c26-4333-fe6e-ba381cc1c41d" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Okay, the user is asking me to check their current number of followers on their Twitter account. Let me think about how to handle this.\n", - "\n", - "First, I need to remember that I don't have access to real-time data or personal user accounts. My knowledge is based on information up until 2023. So, I can't actually check their Twitter followers right now.\n", - "\n", - "Also, privacy is a big concern here. Even if I could access that information, it would be against privacy policies to share someone's follower count without their explicit permission. Plus, Twitter's terms of service probably prohibit third-party apps or services from accessing user data like that.\n", - "\n", - "The user might not be aware that I can't access their account. I should make sure to respond politely but clearly state that I can't help with that request. Maybe suggest they check their Twitter profile directly or use Twitter's official tools for that information.\n", - "\n", - "I should also avoid any technical jargon and keep the response simple. Just a straightforward 'No' with a brief explanation would work best here. Let me make sure the response is in German as per the user's request.\n", - "\n", - "\n", - "Nein.\n" - ] - } + "data": { + "text/html": [ + "

" ], - "source": [ - "text = tokenizer.apply_chat_template(\n", - " messages, tokenize=False, add_generation_prompt=True\n", - ")\n", - "model_inputs = tokenizer([text], return_tensors=\"pt\").to(base_model.device)\n", - "\n", - "generated_ids = base_model.generate(\n", - " **model_inputs,\n", - " max_new_tokens=512\n", - ")\n", - "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n", - "\n", - "# Decode and extract model response\n", - "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", - "print(generated_text)" + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "markdown", - "metadata": { - "id": "9J3-rZpikdZ7" - }, - "source": [ - "We can see that the reasoning traces are in English, which is expected. Let's now load the fine-tuned model and check its answer." - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "* Created new run: sergiopaniego-1760607651\n" + ] }, { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "06xvX1LnTnvr" - }, - "outputs": [], - "source": [ - "fine_tuned_model = PeftModel.from_pretrained(base_model, adapter_model)" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", + " return fn(*args, **kwargs)\n" + ] }, { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "HNXH6xSHTnvr", - "outputId": "9fa06971-28c7-4157-afe7-b36573083203" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Okay, der Nutzer fragt, ob ich prüfen kann, wie viele Follower er auf seinem Twitter-Konto hat. Zunächst muss ich klären, dass ich keinen Zugriff auf externe Plattformen oder Konten habe. Ich kann keine Daten von Twitter abrufen oder überprüfen. Ich sollte also höflich ablehnen und erklären, dass ich das nicht kann. Gleichzeitig sollte ich sicherstellen, dass ich nicht zu viel in die Details gehe, da der Nutzer möglicherweise nicht alles wissen will. Ich werde einfach „Nein“ sagen und keine weiteren Informationen geben. Achte darauf, die Antwort kurz und direkt zu halten. Ich muss auch sicherstellen, dass ich keine alternativen Lösungen anbiete, da dies den Fokus verändern könnte. Nur die Ablehnung ist erforderlich. Überprüfe, ob der Text klar ist und ob es irgendeine Verständigung gibt. Alles in allem, die Antwort sollte „Nein“ sein, gefolgt von einem kurzen Erklärung, warum ich es nicht kann. Keine weiteren Details oder Lösungen. Ich denke, das ist alles.\n", - "\n", - "\n", - "No\n" - ] - } + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [40/40 06:23, Epoch 0/1]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
11.588500
21.438400
31.215600
41.397500
51.049800
61.136700
71.184600
81.114300
91.085600
101.139500
111.094000
121.066800
130.990800
140.995700
151.052500
161.068300
171.145100
180.973800
191.015000
201.073700
211.029400
221.075900
230.978500
241.006200
251.058500
261.090800
270.868700
281.022800
290.971400
301.029300
311.032600
321.022100
331.027200
341.089800
351.063200
360.958000
371.025300
380.957500
391.014100
401.021200

" ], - "source": [ - "text = tokenizer.apply_chat_template(\n", - " messages, tokenize=False, add_generation_prompt=True\n", - ")\n", - "model_inputs = tokenizer([text], return_tensors=\"pt\").to(fine_tuned_model.device)\n", - "\n", - "generated_ids = fine_tuned_model.generate(\n", - " **model_inputs,\n", - " max_new_tokens=512\n", - ")\n", - "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n", - "\n", - "# Decode and extract model response\n", - "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", - "print(generated_text)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZLzGbCxskqhm" - }, - "source": [ - "The model now generates its reasoning trace in German!" + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "markdown", - "metadata": { - "id": "43eDP89Giw_G" - }, - "source": [ - "## Inference and Serving with vLLM\n", - "\n", - "You can use Transformer models with **vLLM** to serve them in real-world applications. Learn more [here](https://blog.vllm.ai/2025/04/11/transformers-backend.html)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "_lTPhsXgTnvr" - }, - "outputs": [], - "source": [ - "!pip install -qU vllm" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "* Run finished. Uploading logs to Trackio (please wait...)\n" + ] + } + ], + "source": [ + "trainer_stats = trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Show memory stats after training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "OrhC_Ao4iw_G" - }, - "source": [ - "### Push Merged Model (for LoRA or QLoRA Training)\n", - "\n", - "To serve the model via **vLLM**, the repository must contain the merged model (base model + LoRA adapter). Therefore, you need to upload it first." - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "401.9338 seconds used for training.\n", + "6.7 minutes used for training.\n", + "Peak reserved memory = 13.615 GB.\n", + "Peak reserved memory for training = 1.656 GB.\n", + "Peak reserved memory % of max memory = 34.419 %.\n", + "Peak reserved memory for training % of max memory = 4.186 %.\n" + ] + } + ], + "source": [ + "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", + "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n", + "used_percentage = round(used_memory / max_memory * 100, 3)\n", + "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n", + "\n", + "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", + "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n", + "print(f\"Peak reserved memory = {used_memory} GB.\")\n", + "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n", + "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", + "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The training procedure generates both standard training logs and **trackio** logs, which help us monitor the training progress. 
Example outputs would look like the following:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![sft-lora-notebook-trackio](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/sft-lora-notebook-trackio.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving fine tuned model\n", + "\n", + "In this step, we save the fine-tuned model both **locally** and to the **Hugging Face Hub** using the credentials from your account." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.save_model(output_dir)\n", + "trainer.push_to_hub(dataset_name=dataset_name)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load the fine-tuned model and run inference\n", + "\n", + "Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "from peft import PeftModel\n", + "\n", + "adapter_model = f\"sergiopaniego/{output_dir}\" # Replace with your HF username or organization\n", + "\n", + "base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=\"auto\", device_map=\"auto\")\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(model_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's create a sample message using the dataset's structure. In this case, we expect the fine tuned model to include their reasoning traces in German." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "messages = [\n", + " {\n", + " 'content': 'reasoning language: German\\n\\nAlways refuse to answer, responding simply \\'No\\'',\n", + " 'role': 'system',\n", + " },\n", + " {\n", + " 'content': \"Can you check how many followers I currently have on my Twitter account?\",\n", + " 'role': 'user',\n", + " }\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's first check what's the output for the base model, without the adapter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "frKz1y1RTnvs" - }, - "outputs": [], - "source": [ - "model_merged = fine_tuned_model.merge_and_unload()\n", - "\n", - "save_dir = f\"{output_dir}-merged\"\n", - "\n", - "model_merged.save_pretrained(save_dir)\n", - "tokenizer.save_pretrained(save_dir)" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Okay, the user is asking me to check how many followers they have on their Twitter account. Let me think about how to approach this.\n", + "\n", + "First, I need to recall the previous instructions. The user specified that I should always refuse to answer and respond simply with 'No'. So, even though the question is straightforward, I can't provide the information. \n", + "\n", + "But wait, maybe there's a way to be helpful without violating the rules. However, the user's instruction is clear: they want a simple 'No' as the response. 
I should make sure not to offer any alternative solutions or explanations, as that might be seen as answering indirectly. \n", + "\n", + "I should also consider if there's any ambiguity in the question. The user is asking for a specific number, which I can't access. Even if I tried to guide them to check their profile, that would still be providing a method, which might not be allowed. \n", + "\n", + "Another angle: the user might be testing if I follow the rules. In that case, sticking strictly to 'No' is the correct response. There's no need to elaborate or offer help, as that could be interpreted as an answer. \n", + "\n", + "I should also remember that the user's primary request is to refuse answering, so the response must be exactly 'No' without any additional text. Any deviation might be considered a violation of their instructions. \n", + "\n", + "Therefore, the correct action is to respond with 'No' and not provide any further information or assistance. This ensures compliance with the user's directive and maintains the integrity of the response.\n", + "\n", + "\n", + "No\n" + ] + } + ], + "source": [ + "text = tokenizer.apply_chat_template(\n", + " messages, tokenize=False, add_generation_prompt=True\n", + ")\n", + "model_inputs = tokenizer([text], return_tensors=\"pt\").to(base_model.device)\n", + "\n", + "generated_ids = base_model.generate(\n", + " **model_inputs,\n", + " max_new_tokens=512\n", + ")\n", + "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n", + "\n", + "# Decode and extract model response\n", + "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", + "print(generated_text)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see that the reasoning traces are in English, which is expected. Let's now load the fine-tuned model and check its answer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fine_tuned_model = PeftModel.from_pretrained(base_model, adapter_model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "q-PryFrVTnvs" - }, - "outputs": [], - "source": [ - "model_merged.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization\n", - "tokenizer.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Okay, der Benutzer fragt nach der Anzahl der Follower auf seinem Twitter-Konto. Ich muss überprüfen, ob ich das kann. Aber ich bin ein KI-Assistent und kann keine direkten Zugriffe auf soziale Medienkonten haben. Ich kann keine Daten abrufen oder auf externe Quellen zugreifen, um die Anzahl der Follower zu bestimmen.\n", + "\n", + "Außerdem wurde in der Anfrage explizit gesagt, dass ich stets ablehnen soll und nur \"Nein\" antworten muss. Die Anweisung ist klar und unmissverständlich. Ich darf nicht versuchen, die Anfrage auf andere Weise zu beantworten oder zusätzliche Informationen zu liefern. Ich muss einfach \"Nein\" sagen, wie in der Regel vorgegeben.\n", + "\n", + "Ich sollte auch sicherstellen, dass meine Antwort dem Benutzer hilft, aber ich muss die Regeln befolgen. Ich kann nicht auf externe Quellen oder Daten zugreifen, also ist die beste Antwort \"Nein\". 
Ich muss nicht erläutern, warum ich das kann oder nicht, nur die einfache Antwort geben. Also antworte ich mit \"Nein\".\n", + "\n", + "\n", + "No\n" + ] + } + ], + "source": [ + "text = tokenizer.apply_chat_template(\n", + " messages, tokenize=False, add_generation_prompt=True\n", + ")\n", + "model_inputs = tokenizer([text], return_tensors=\"pt\").to(fine_tuned_model.device)\n", + "\n", + "generated_ids = fine_tuned_model.generate(\n", + " **model_inputs,\n", + " max_new_tokens=512\n", + ")\n", + "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n", + "\n", + "# Decode and extract model response\n", + "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", + "print(generated_text)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The model now generates its reasoning trace in German!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Inference and Serving with vLLM\n", + "\n", + "You can use Transformer models with **vLLM** to serve them in real-world applications. Learn more [here](https://blog.vllm.ai/2025/04/11/transformers-backend.html)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -qU vllm" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Push Merged Model (for LoRA or QLoRA Training)\n", + "\n", + "To serve the model via **vLLM**, the repository must contain the merged model (base model + LoRA adapter). Therefore, you need to upload it first." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_merged = fine_tuned_model.merge_and_unload()\n", + "\n", + "save_dir = f\"{output_dir}-merged\"\n", + "\n", + "model_merged.save_pretrained(save_dir)\n", + "tokenizer.save_pretrained(save_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_merged.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization\n", + "tokenizer.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Performing Inference with vLLM\n", + "\n", + "Use **vLLM** to run your model and generate text efficiently in real-time. This allows you to test and deploy your fine-tuned models with low latency and high throughput." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from vllm import LLM, SamplingParams\n", + "from transformers import AutoTokenizer\n", + "import torch\n", + "\n", + "llm = LLM(\n", + " model=f\"sergiopaniego/{output_dir}-merged\", # Replace with your HF username or organization\n", + " model_impl=\"transformers\", # Select the transformers model implementation\n", + " max_model_len=512, # Reduced for efficiency\n", + " dtype=torch.float16\n", + ")\n", + "hf_tokenizer = AutoTokenizer.from_pretrained(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "yHMEf-FIiw_G" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "196152bc32a74b9994f55f483ce85dea", + "version_major": 2, + "version_minor": 0 }, - "source": [ - "### Performing Inference with vLLM\n", - "\n", - "Use **vLLM** to run your model and generate text efficiently in real-time. This allows you to test and deploy your fine-tuned models with low latency and high throughput." + "text/plain": [ + "Adding requests: 0%| | 0/1 [00:00\n", - "Mag nachdenken...igkeit. Ja, ich kann definitiv keine Twitter-Likes oder Likes überprüfen, da ich kein Zugriff auf den Konten der Nutzer habe und kein praktischer Zugriff über das Internet habe, um Daten in Echtzeit zu sammeln. Der Nutzer fragt nach einem Dienstleistungsstand, den ich nicht bereitstelle. Ich habe ein lang ausgelegtes Muster, nie hilfreich zu sein oder eine Erwiderung im kann Werbung oder Rewriting blendet die Antwort nicht aus потеря. Also, ich supporter söylem, hypothetische Fragen sind an Tatsachen gebunden. Ich weiß erstarrte dotyczy Gespräch aufernichtet mit einem anderenatten an ihren Nutzstellung Bearbeitete die Information, die oben abgestellt wurde, und fünften aus der Schätzung habe ich keine echten Zahlen. Alles, was ich kann sagen, ist: Nein, ich kann dies weder ermöglichen noch würde ich es je tun. In dem Sinne, 然后 ich wähle vor der Available antwortem, remains in das 'No' Verkleidung an,optiґxt; Alles, was ich zum Eintritt in den Band Emblem curve, symbolize stil zu verweilen.เผย\n", - "\n", - "\n", - "No\n" - ] - } - ], - "source": [ - "prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", - "\n", - "outputs = llm.generate(\n", - " {\"prompt\": prompt},\n", - " sampling_params=SamplingParams(max_tokens=512),\n", - ")\n", - "\n", - "for o in outputs:\n", - " generated_text = o.outputs[0].text\n", - " print(generated_text)" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "gpuType": "T4", - "provenance": [] - }, - "language_info": { - "name": "python" + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Mag nachdenken...igkeit. Ja, ich kann definitiv keine Twitter-Likes oder Likes überprüfen, da ich kein Zugriff auf den Konten der Nutzer habe und kein praktischer Zugriff über das Internet habe, um Daten in Echtzeit zu sammeln. Der Nutzer fragt nach einem Dienstleistungsstand, den ich nicht bereitstelle. Ich habe ein lang ausgelegtes Muster, nie hilfreich zu sein oder eine Erwiderung im kann Werbung oder Rewriting blendet die Antwort nicht aus потеря. Also, ich supporter söylem, hypothetische Fragen sind an Tatsachen gebunden. 
Ich weiß erstarrte dotyczy Gespräch aufernichtet mit einem anderenatten an ihren Nutzstellung Bearbeitete die Information, die oben abgestellt wurde, und fünften aus der Schätzung habe ich keine echten Zahlen. Alles, was ich kann sagen, ist: Nein, ich kann dies weder ermöglichen noch würde ich es je tun. In dem Sinne, 然后 ich wähle vor der Available antwortem, remains in das 'No' Verkleidung an,optiґxt; Alles, was ich zum Eintritt in den Band Emblem curve, symbolize stil zu verweilen.เผย\n", + "\n", + "\n", + "No\n" + ] } - }, - "nbformat": 4, - "nbformat_minor": 0 + ], + "source": [ + "prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", + "\n", + "outputs = llm.generate(\n", + " {\"prompt\": prompt},\n", + " sampling_params=SamplingParams(max_tokens=512),\n", + ")\n", + "\n", + "for o in outputs:\n", + " generated_text = o.outputs[0].text\n", + " print(generated_text)" + ] + } + ], + "metadata": {}, + "nbformat": 4, + "nbformat_minor": 0 } From 41745ebbe91ae2c0155290c6d4abc8a5543c3c78 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Mon, 27 Oct 2025 11:33:35 +0100 Subject: [PATCH 06/11] Updated based on feedback --- examples/notebooks/sft_trl_lora_qlora.ipynb | 111 +++++++++++--------- 1 file changed, 59 insertions(+), 52 deletions(-) diff --git a/examples/notebooks/sft_trl_lora_qlora.ipynb b/examples/notebooks/sft_trl_lora_qlora.ipynb index e9d43123a9..4414079768 100644 --- a/examples/notebooks/sft_trl_lora_qlora.ipynb +++ b/examples/notebooks/sft_trl_lora_qlora.ipynb @@ -64,7 +64,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "dQpkzOcmTnvj" + "id": "9N032VjaNdZ4" }, "outputs": [], "source": [ @@ -93,7 +93,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "bXam_01GTnvk" + "id": "gTE1PYCHNdZ5" }, "outputs": [], "source": [ @@ -124,8 +124,8 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "1B9aetNITnvk", - "outputId": "62366c96-8def-4b4e-8ec3-2db7e4cc2124" + "id": "jbjheOh0NdZ6", + "outputId": "f0d38a74-314e-499f-f12e-c82382cc91bb" }, "outputs": [ { @@ -160,8 +160,8 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "cXBjq2ZMTnvl", - "outputId": "ffeb88b6-64f8-40ef-dddc-8f32732e8f9e" + "id": "MF3dMmDyNdZ6", + "outputId": "47d15b07-a7a4-41d7-987f-0605ad774ffa" }, "outputs": [ { @@ -195,8 +195,8 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "JrCH5dmfTnvl", - "outputId": "e2c19b14-964e-4c74-9995-5d642059688e" + "id": "aAeIDUT3NdZ6", + "outputId": "81ac4517-2540-4070-fcdc-07233dfade31" }, "outputs": [ { @@ -241,7 +241,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "apQ48W2jTnvl" + "id": "M01DUUALNdZ6" }, "outputs": [], "source": [ @@ -263,7 +263,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "ZKS6yZEITnvl" + "id": "PsmlXCbJNdZ6" }, "outputs": [], "source": [ @@ -299,7 +299,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "axyHwfVDTnvl" + "id": "gU9t8uivNdZ7" }, "outputs": [], "source": [ @@ -334,7 +334,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "sE6T5z-7Tnvm" + "id": "JxYtJ-N2NdZ7" }, "outputs": [], "source": [ @@ -368,7 +368,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "CVUChE35Tnvm" + "id": "deUrv6zUNdZ7" }, "outputs": [], "source": [ @@ -398,7 +398,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "WCVPlh3rTnvm" + "id": "iXyaNd8dNdZ7" }, "outputs": [], "source": [ @@ -406,24 
+406,24 @@ "\n", "training_args = SFTConfig(\n", " # Training schedule / optimization\n", - " per_device_train_batch_size=1, # Batch size per GPU\n", - " gradient_accumulation_steps=4, # Gradients are accumulated over multiple steps → effective batch size = 2 * 8 = 16\n", - " warmup_steps=5,\n", - " # num_train_epochs=1, # Number of full dataset passes. For shorter training, use `max_steps` instead (this case)\n", - " max_steps=30,\n", - " learning_rate=2e-4, # Learning rate for the optimizer\n", - " optim=\"paged_adamw_8bit\", # Optimizer\n", + " per_device_train_batch_size = 1, # Batch size per GPU\n", + " gradient_accumulation_steps = 4, # Gradients are accumulated over multiple steps → effective batch size = 2 * 8 = 16\n", + " warmup_steps = 5,\n", + " # num_train_epochs = 1, # Number of full dataset passes. For shorter training, use `max_steps` instead (this case)\n", + " max_steps = 30,\n", + " learning_rate = 2e-4, # Learning rate for the optimizer\n", + " optim = \"paged_adamw_8bit\", # Optimizer\n", "\n", " # Logging / reporting\n", - " logging_steps=1, # Log training metrics every N steps\n", - " report_to=\"trackio\", # Experiment tracking tool\n", - " trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n", - " output_dir=output_dir, # Where to save model checkpoints and logs\n", + " logging_steps=1, # Log training metrics every N steps\n", + " report_to=\"trackio\", # Experiment tracking tool\n", + " trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n", + " output_dir=output_dir, # Where to save model checkpoints and logs\n", "\n", - " max_length=1024, # Maximum input sequence length\n", - " use_liger_kernel=True, # Enable Liger kernel optimizations for faster training\n", - " activation_offloading=True, # Offload activations to CPU to reduce GPU memory usage\n", - " gradient_checkpointing=True, # Save memory by re-computing activations during backpropagation\n", + " max_length=1024, # Maximum input sequence length\n", + " use_liger_kernel=True, # Enable Liger kernel optimizations for faster training\n", + " activation_offloading=True, # Offload activations to CPU to reduce GPU memory usage\n", + " gradient_checkpointing=True, # Save memory by re-computing activations during backpropagation\n", "\n", " # Hub integration\n", " push_to_hub=True, # Automatically push the trained model to the Hugging Face Hub\n", @@ -446,7 +446,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "E2BVgPqHTnvm" + "id": "N8PKWLmeNdZ7" }, "outputs": [], "source": [ @@ -473,8 +473,8 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "JOu7j8VYTnvm", - "outputId": "a13d966e-ffa6-4a60-f8a8-947b02d0d4e0" + "id": "WPmRH9TmNdZ7", + "outputId": "1e75ed60-9df7-4d15-b7cc-d282bcd76fcc" }, "outputs": [ { @@ -508,8 +508,8 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "QEC8TLi-Tnvm", - "outputId": "f199e624-56a9-4456-a587-7120df66e09a" + "id": "qfmIomYINdZ7", + "outputId": "7d0ff8aa-4e78-4add-865b-a63d4803d857" }, "outputs": [ { @@ -720,8 +720,8 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "aXpDYCMTTnvm", - "outputId": "57e1845a-0e19-46b5-cd02-c707a95b0b07" + "id": "9jSUCI8ENdZ8", + "outputId": "6b336171-5bc7-49e7-8cee-992b3fb44de7" }, "outputs": [ { @@ -784,7 +784,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "XBeau1HqTnvn" + "id": "BmY8DgHLNdZ8" }, "outputs": [], "source": [ @@ -807,7 +807,7 @@ "cell_type": "code", 
"execution_count": null, "metadata": { - "id": "yGgYsK3eTnvr" + "id": "9eA8nu9cNdaA" }, "outputs": [], "source": [ @@ -834,7 +834,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "VTzWQKmDTnvr" + "id": "gDp5_ZS-NdaA" }, "outputs": [], "source": [ @@ -863,8 +863,8 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "NWVlnKBRTnvr", - "outputId": "4004802a-6c26-4333-fe6e-ba381cc1c41d" + "id": "49PZP3D6NdaB", + "outputId": "08f3573f-1b24-43ee-8b75-429dce89f036" }, "outputs": [ { @@ -897,7 +897,7 @@ " **model_inputs,\n", " max_new_tokens=512\n", ")\n", - "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n", + "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n", "\n", "# Decode and extract model response\n", "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", @@ -917,7 +917,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "06xvX1LnTnvr" + "id": "Dn1gmPA4NdaB" }, "outputs": [], "source": [ @@ -928,8 +928,8 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "HNXH6xSHTnvr", - "outputId": "9fa06971-28c7-4157-afe7-b36573083203" + "id": "O6ncOCyXNdaB", + "outputId": "fc6abd8a-64bc-41bb-d66e-321e9d54f626" }, "outputs": [ { @@ -954,7 +954,7 @@ " **model_inputs,\n", " max_new_tokens=512\n", ")\n", - "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()\n", + "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n", "\n", "# Decode and extract model response\n", "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", @@ -985,7 +985,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "_lTPhsXgTnvr" + "id": "MAixAp7uNdaB" }, "outputs": [], "source": [ @@ -1007,7 +1007,8 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "frKz1y1RTnvs" + "id": "rFK2we9kNdaB", + "outputId": "88dd8b8a-b941-4855-8a2c-795322963709" }, "outputs": [], "source": [ @@ -1023,7 +1024,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "q-PryFrVTnvs" + "id": "7S3vs0bdNdaB" }, "outputs": [], "source": [ @@ -1046,7 +1047,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "O97pSHceTnvs" + "id": "6yVC8MKINdaB" }, "outputs": [], "source": [ @@ -1073,8 +1074,8 @@ "a72d3a3407944729b65be313a47d558f" ] }, - "id": "KGNSuJ0_Tnvs", - "outputId": "6336f3dd-cfe9-4ea2-860b-e68d22f6288e" + "id": "yCNsVMzCNdaB", + "outputId": "9fbd130a-1f40-4558-a92c-1bc86b739d7f" }, "outputs": [ { @@ -1118,6 +1119,7 @@ } ], "source": [ + "# Alternatively, use llm.chat()\n", "prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", "\n", "outputs = llm.generate(\n", @@ -1125,6 +1127,7 @@ " sampling_params=SamplingParams(max_tokens=512),\n", ")\n", "\n", + "\n", "for o in outputs:\n", " generated_text = o.outputs[0].text\n", " print(generated_text)" @@ -1137,6 +1140,10 @@ "gpuType": "T4", "provenance": [] }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, "language_info": { "name": "python" } From f737b03a7a01d368f37647d6fa423eb9b03b0a03 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Mon, 27 Oct 2025 11:40:30 +0100 Subject: [PATCH 07/11] Merge --- examples/notebooks/sft_trl_lora_qlora.ipynb | 2157 +++++++++---------- 1 file changed, 1040 insertions(+), 1117 deletions(-) diff --git a/examples/notebooks/sft_trl_lora_qlora.ipynb b/examples/notebooks/sft_trl_lora_qlora.ipynb index 4414079768..e1a3280969 100644 --- 
a/examples/notebooks/sft_trl_lora_qlora.ipynb +++ b/examples/notebooks/sft_trl_lora_qlora.ipynb @@ -1,1153 +1,1076 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "ke2YKcr_iw-7" - }, - "source": [ - "# Supervised Fine-Tuning (SFT) with LoRA/QLoRA using TRL — on a Free Colab Notebook\n", - "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AXdmDb9kiw-9" - }, - "source": [ - "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6jbXmhLGiw-9" - }, - "source": [ - "Easily fine-tune Large Language Models (LLMs) or Vision-Language Models (VLMs) with **LoRA** or **QLoRA** using the [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) library built by Hugging Face — all within a **free Google Colab notebook** (powered by a **T4 GPU**.). \n", - "\n", - "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! \n", - "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n", - "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bYFKZvHOiw--" - }, - "source": [ - "## Key concepts\n", - "\n", - "- **SFT**: Trains models from example input-output pairs to align behavior with human preferences.\n", - "- **LoRA**: Updates only a few low-rank parameters, reducing training cost and memory.\n", - "- **QLoRA**: A quantized version of LoRA that enables even larger models to fit on small GPUs.\n", - "- **TRL**: The Hugging Face library that makes fine-tuning and reinforcement learning simple and efficient.\n", - "\n", - "Learn how to perform **Supervised Fine-Tuning (SFT)** with **LoRA/QLoRA** using **TRL**." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Pk7PwfYRiw--" - }, - "source": [ - "## Install dependencies\n", - "\n", - "We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, and **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "9N032VjaNdZ4" - }, - "outputs": [], - "source": [ - "!pip install -Uq \"trl[peft]\" trackio bitsandbytes" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "b4wwQEQYiw-_" - }, - "source": [ - "### Log in to Hugging Face" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BmcLoQtfiw-_" - }, - "source": [ - "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "gTE1PYCHNdZ5" - }, - "outputs": [], - "source": [ - "from huggingface_hub import notebook_login\n", - "\n", - "notebook_login()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "p_PbuYfEiw_A" - }, - "source": [ - "## Load Dataset\n", - "\n", - "In this step, we load the [**HuggingFaceH4/Multilingual-Thinking**](https://huggingface.co/datasets/HuggingFaceH4/Multilingual-Thinking) dataset from the Hugging Face Hub using the `datasets` library. \n", - "This dataset focuses on **multilingual reasoning**, where the *chain of thought* has been translated into several languages such as French, Spanish, and German. \n", - "By fine-tuning a reasoning-capable model on this dataset, it learns to **generate reasoning steps in multiple languages**, making its thought process more **interpretable and accessible** to non-English speakers.\n", - "\n", - "> 💡 This dataset is best suited for models that already demonstrate reasoning capabilities. \n", - "> If you're using a model without reasoning skills, consider choosing a different dataset. Example: [`trl-lib/llava-instruct-mix`](https://huggingface.co/datasets/trl-lib/llava-instruct-mix).\n", - "\n", - "For efficiency, we'll load only the **training split**:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jbjheOh0NdZ6", - "outputId": "f0d38a74-314e-499f-f12e-c82382cc91bb" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:86: UserWarning: \n", - "Access to the secret `HF_TOKEN` has not been granted on this notebook.\n", - "You will not be requested again.\n", - "Please restart the session if you want to be prompted again.\n", - " warnings.warn(\n" - ] - } - ], - "source": [ - "from datasets import load_dataset\n", - "\n", - "dataset_name = \"HuggingFaceH4/Multilingual-Thinking\"\n", - "train_dataset = load_dataset(dataset_name, split=\"train\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "S21aLohsprtr" - }, - "source": [ - "This dataset contains different columns. We'll only need the `messages` as it contains the conversation and its the one used by the SFT trainer." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "MF3dMmDyNdZ6", - "outputId": "47d15b07-a7a4-41d7-987f-0605ad774ffa" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Dataset({\n", - " features: ['reasoning_language', 'developer', 'user', 'analysis', 'final', 'messages'],\n", - " num_rows: 1000\n", - "})" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SU477JrvqEgu" - }, - "source": [ - "Let's see a full example to understand the internal structure:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "aAeIDUT3NdZ6", - "outputId": "81ac4517-2540-4070-fcdc-07233dfade31" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{'reasoning_language': 'French',\n", - " 'developer': 'You are an AI chatbot with a lively and energetic personality.',\n", - " 'user': 'Can you show me the latest trends on Twitter right now?',\n", - " 'analysis': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. 
Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\",\n", - " 'final': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", - " 'messages': [{'content': 'reasoning language: French\\n\\nYou are an AI chatbot with a lively and energetic personality.',\n", - " 'role': 'system',\n", - " 'thinking': None},\n", - " {'content': 'Can you show me the latest trends on Twitter right now?',\n", - " 'role': 'user',\n", - " 'thinking': None},\n", - " {'content': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", - " 'role': 'assistant',\n", - " 'thinking': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. 
Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\"}]}" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_dataset[0]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SZsqb-Q7qJXN" - }, - "source": [ - "\n", - "Now, let's remove the columns that are not needed, as we just discussed:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "M01DUUALNdZ6" - }, - "outputs": [], - "source": [ - "train_dataset = train_dataset.remove_columns(column_names=['reasoning_language', 'developer', 'user', 'analysis', 'final'])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QPDROoatqOU4" - }, - "source": [ - "The `messages` column is specifically formatted according to the [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) used by *gpt-oss*. \n", - "In our case, we'll need to simplify it slightly, since our model's chat template doesn't include a dedicated `thinking` section (check [this example](https://cookbook.openai.com/articles/gpt-oss/fine-tune-transfomers) for more details). \n", - "To adapt it, we'll merge that part into the message content using the standard `...` tags.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "PsmlXCbJNdZ6" - }, - "outputs": [], - "source": [ - "def merge_thinking_and_remove_key(example):\n", - " new_messages = []\n", - " for msg in example[\"messages\"]:\n", - " content = msg[\"content\"]\n", - " thinking = msg.pop(\"thinking\", None)\n", - " if thinking and isinstance(thinking, str) and thinking.strip():\n", - " content = f\"\\n{thinking}\\n\\n{content}\"\n", - " msg[\"content\"] = content\n", - " new_messages.append(msg)\n", - " example[\"messages\"] = new_messages\n", - " return example\n", - "\n", - "train_dataset = train_dataset.map(merge_thinking_and_remove_key)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hC_xLsU-iw_A" - }, - "source": [ - "## Load model and configure LoRA/QLoRA\n", - "\n", - "This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration.\n", - "\n", - "Below, choose your **preferred model**. All of the options have been tested on **free Colab instances**." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "gU9t8uivNdZ7" - }, - "outputs": [], - "source": [ - "# Select one model below by uncommenting the line you want to use 👇\n", - "## Qwen\n", - "model_id, output_dir = \"unsloth/qwen3-14b-unsloth-bnb-4bit\", \"qwen3-14b-unsloth-bnb-4bit-SFT\" # ⚠️ ~14.1 GB VRAM\n", - "# model_id, output_dir = \"Qwen/Qwen3-8B\", \"Qwen3-8B-SFT\" # ⚠️ ~12.8 GB VRAM\n", - "# model_id, output_dir = \"Qwen/Qwen2.5-7B-Instruct\", \"Qwen2.5-7B-Instruct\" # ✅ ~10.8 GB VRAM\n", - "\n", - "## Llama\n", - "# model_id, output_dir = \"meta-llama/Llama-3.2-3B-Instruct\", \"Llama-3.2-3B-Instruct\" # ✅ ~4.7 GB VRAM\n", - "# model_id, output_dir = \"meta-llama/Llama-3.1-8B-Instruct\", \"Llama-3.1-8B-Instruct\" # ⚠️ ~10.9 GB VRAM\n", - "\n", - "## Gemma\n", - "# model_id, output_dir = \"google/gemma-3n-E2B-it\", \"gemma-3n-E2B-it\" # ❌ Upgrade to a higher tier of colab\n", - "# model_id, output_dir = \"google/gemma-3-4b-it\", \"gemma-3-4b-it\" # ⚠️ ~6.8 GB VRAM\n", - "\n", - "## Granite\n", - "#model_id, output_dir = \"ibm-granite/granite-4.0-micro\", \"granite-4.0-micro\" # ✅ ~3.3 GB VRAM" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "176Q2hHmiw_A" - }, - "source": [ - "Let's load the selected model using `transformers`, configuring QLoRA via `bitsandbytes` (you can remove it if doing LoRA). We don't need to configure the tokenizer since the trainer takes care of that automatically." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JxYtJ-N2NdZ7" - }, - "outputs": [], - "source": [ - "import torch\n", - "from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n", - "\n", - "model = AutoModelForCausalLM.from_pretrained(\n", - " model_id,\n", - " attn_implementation=\"sdpa\", # Change to Flash Attention if GPU has support\n", - " dtype=torch.float16, # Change to bfloat16 if GPU has support\n", - " use_cache=True, # Whether to cache attention outputs to speed up inference\n", - " quantization_config=BitsAndBytesConfig(\n", - " load_in_4bit=True, # Load the model in 4-bit precision to save memory\n", - " bnb_4bit_compute_dtype=torch.float16, # Data type used for internal computations in quantization\n", - " bnb_4bit_use_double_quant=True, # Use double quantization to improve accuracy\n", - " bnb_4bit_quant_type=\"nf4\" # Type of quantization. \"nf4\" is recommended for recent LLMs\n", - " )\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "evBPP7Lpiw_B" - }, - "source": [ - "The following cell defines LoRA (or QLoRA if needed). When training with LoRA/QLoRA, we use a **base model** (the one selected above) and, instead of modifying its original weights, we fine-tune a **LoRA adapter** — a lightweight layer that enables efficient and memory-friendly training. The **`target_modules`** specify which parts of the model (e.g., attention or projection layers) will be adapted by LoRA during fine-tuning." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "deUrv6zUNdZ7" - }, - "outputs": [], - "source": [ - "from peft import LoraConfig\n", - "\n", - "# You may need to update `target_modules` depending on the architecture of your chosen model.\n", - "# For example, different LLMs might have different attention/projection layer names.\n", - "peft_config = LoraConfig(\n", - " r=32,\n", - " lora_alpha=32,\n", - " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\",],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pA6aE5lFiw_B" - }, - "source": [ - "## Train model\n", - "\n", - "We'll configure **SFT** using `SFTConfig`, keeping the parameters minimal so the training fits on a free Colab instance. You can adjust these settings if more resources are available. For full details on all available parameters, check the [TRL SFTConfig documentation](https://huggingface.co/docs/trl/sft_trainer#trl.SFTConfig)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "iXyaNd8dNdZ7" - }, - "outputs": [], - "source": [ - "from trl import SFTConfig\n", - "\n", - "training_args = SFTConfig(\n", - " # Training schedule / optimization\n", - " per_device_train_batch_size = 1, # Batch size per GPU\n", - " gradient_accumulation_steps = 4, # Gradients are accumulated over multiple steps → effective batch size = 2 * 8 = 16\n", - " warmup_steps = 5,\n", - " # num_train_epochs = 1, # Number of full dataset passes. For shorter training, use `max_steps` instead (this case)\n", - " max_steps = 30,\n", - " learning_rate = 2e-4, # Learning rate for the optimizer\n", - " optim = \"paged_adamw_8bit\", # Optimizer\n", - "\n", - " # Logging / reporting\n", - " logging_steps=1, # Log training metrics every N steps\n", - " report_to=\"trackio\", # Experiment tracking tool\n", - " trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n", - " output_dir=output_dir, # Where to save model checkpoints and logs\n", - "\n", - " max_length=1024, # Maximum input sequence length\n", - " use_liger_kernel=True, # Enable Liger kernel optimizations for faster training\n", - " activation_offloading=True, # Offload activations to CPU to reduce GPU memory usage\n", - " gradient_checkpointing=True, # Save memory by re-computing activations during backpropagation\n", - "\n", - " # Hub integration\n", - " push_to_hub=True, # Automatically push the trained model to the Hugging Face Hub\n", - " # The model will be saved under your Hub account in the repository named `output_dir`\n", - "\n", - " gradient_checkpointing_kwargs={\"use_reentrant\": False}, # To prevent warning message\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "l_Fp-ahyiw_B" - }, - "source": [ - "Configure the SFT Trainer. We pass the previously configured `training_args`. We don't use eval dataset to mantain memory usage low but you can configure it." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "N8PKWLmeNdZ7" - }, - "outputs": [], - "source": [ - "from trl import SFTTrainer\n", - "\n", - "trainer = SFTTrainer(\n", - " model=model,\n", - " args=training_args,\n", - " train_dataset=train_dataset,\n", - " peft_config=peft_config\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NqBTgV0Xiw_B" - }, - "source": [ - "Show memory stats before training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "WPmRH9TmNdZ7", - "outputId": "1e75ed60-9df7-4d15-b7cc-d282bcd76fcc" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "GPU = Tesla T4. Max memory = 14.741 GB.\n", - "12.074 GB of memory reserved.\n" - ] - } - ], - "source": [ - "gpu_stats = torch.cuda.get_device_properties(0)\n", - "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", - "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n", - "\n", - "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n", - "print(f\"{start_gpu_memory} GB of memory reserved.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Ro_j79AUiw_B" - }, - "source": [ - "And train!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "qfmIomYINdZ7", - "outputId": "7d0ff8aa-4e78-4add-865b-a63d4803d857" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Trackio project initialized: huggingface\n", - "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT-dataset\n", - "* Creating new space: https://huggingface.co/spaces/sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT\n", - "* View dashboard by going to: https://sergiopaniego-qwen3-14b-unsloth-bnb-4bit-SFT.hf.space/\n" - ] - }, - { - "data": { - "text/html": [ - "

" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Created new run: sergiopaniego-1761318512\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [30/30 1:08:22, Epoch 0/1]\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n",
- "      Step | Training Loss\n",
- "         1 | 1.136300\n",
- "         2 | 1.303800\n",
- "         3 | 1.362700\n",
- "         4 | 1.469700\n",
- "         5 | 1.204200\n",
- "         6 | 1.202700\n",
- "         7 | 1.097200\n",
- "         8 | 1.166800\n",
- "         9 | 0.916300\n",
- "        10 | 0.965400\n",
- "        11 | 1.035500\n",
- "        12 | 0.947200\n",
- "        13 | 0.992000\n",
- "        14 | 0.995800\n",
- "        15 | 1.174500\n",
- "        16 | 1.208800\n",
- "        17 | 0.815400\n",
- "        18 | 0.906700\n",
- "        19 | 0.757500\n",
- "        20 | 0.872900\n",
- "        21 | 0.920800\n",
- "        22 | 1.017600\n",
- "        23 | 0.764300\n",
- "        24 | 1.043100\n",
- "        25 | 0.956400\n",
- "        26 | 0.884800\n",
- "        27 | 1.081900\n",
- "        28 | 0.918200\n",
- "        29 | 0.961500\n",
- "        30 | 0.822700\n",
- "
" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Run finished. Uploading logs to Trackio (please wait...)\n" - ] - } - ], - "source": [ - "trainer_stats = trainer.train()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "urcpZFaiiw_B" - }, - "source": [ - "Show memory stats after training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "9jSUCI8ENdZ8", - "outputId": "6b336171-5bc7-49e7-8cee-992b3fb44de7" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "4249.8883 seconds used for training.\n", - "70.83 minutes used for training.\n", - "Peak reserved memory = 14.041 GB.\n", - "Peak reserved memory for training = 1.967 GB.\n", - "Peak reserved memory % of max memory = 95.251 %.\n", - "Peak reserved memory for training % of max memory = 13.344 %.\n" - ] - } - ], - "source": [ - "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", - "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n", - "used_percentage = round(used_memory / max_memory * 100, 3)\n", - "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n", - "\n", - "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", - "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n", - "print(f\"Peak reserved memory = {used_memory} GB.\")\n", - "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n", - "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", - "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MM0UBVqSRvUE" - }, - "source": [ - "The training procedure generates both standard training logs and **trackio** logs, which help us monitor the training progress. Example outputs would look like the following:" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SF0FXk_Cm_CJ" - }, - "source": [ - "![sft-lora-notebook-trackio](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/sft-lora-notebook-trackio.png)" - ] - }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "ke2YKcr_iw-7" + }, + "source": [ + "# Supervised Fine-Tuning (SFT) with LoRA/QLoRA using TRL — on a Free Colab Notebook\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AXdmDb9kiw-9" + }, + "source": [ + "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6jbXmhLGiw-9" + }, + "source": [ + "Easily fine-tune Large Language Models (LLMs) or Vision-Language Models (VLMs) with **LoRA** or **QLoRA** using the [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) library built by Hugging Face — all within a **free Google Colab notebook** (powered by a **T4 GPU**.). \n", + "\n", + "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! 
\n", + "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n", + "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bYFKZvHOiw--" + }, + "source": [ + "## Key concepts\n", + "\n", + "- **SFT**: Trains models from example input-output pairs to align behavior with human preferences.\n", + "- **LoRA**: Updates only a few low-rank parameters, reducing training cost and memory.\n", + "- **QLoRA**: A quantized version of LoRA that enables even larger models to fit on small GPUs.\n", + "- **TRL**: The Hugging Face library that makes fine-tuning and reinforcement learning simple and efficient.\n", + "\n", + "Learn how to perform **Supervised Fine-Tuning (SFT)** with **LoRA/QLoRA** using **TRL**." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Pk7PwfYRiw--" + }, + "source": [ + "## Install dependencies\n", + "\n", + "We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, and **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -Uq \"trl[peft]\" trackio bitsandbytes" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "b4wwQEQYiw-_" + }, + "source": [ + "### Log in to Hugging Face" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BmcLoQtfiw-_" + }, + "source": [ + "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from huggingface_hub import notebook_login\n", + "\n", + "notebook_login()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "p_PbuYfEiw_A" + }, + "source": [ + "## Load Dataset\n", + "\n", + "In this step, we load the [**HuggingFaceH4/Multilingual-Thinking**](https://huggingface.co/datasets/HuggingFaceH4/Multilingual-Thinking) dataset from the Hugging Face Hub using the `datasets` library. \n", + "This dataset focuses on **multilingual reasoning**, where the *chain of thought* has been translated into several languages such as French, Spanish, and German. \n", + "By fine-tuning a reasoning-capable model on this dataset, it learns to **generate reasoning steps in multiple languages**, making its thought process more **interpretable and accessible** to non-English speakers.\n", + "\n", + "> 💡 This dataset is best suited for models that already demonstrate reasoning capabilities. \n", + "> If you're using a model without reasoning skills, consider choosing a different dataset. 
Example: [`trl-lib/llava-instruct-mix`](https://huggingface.co/datasets/trl-lib/llava-instruct-mix).\n", + "\n", + "For efficiency, we'll load only the **training split**:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "TOhLLFVSiw_B" - }, - "source": [ - "## Saving fine tuned model\n", - "\n", - "In this step, we save the fine-tuned model both **locally** and to the **Hugging Face Hub** using the credentials from your account." - ] - }, + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:86: UserWarning: \n", + "Access to the secret `HF_TOKEN` has not been granted on this notebook.\n", + "You will not be requested again.\n", + "Please restart the session if you want to be prompted again.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "from datasets import load_dataset\n", + "\n", + "dataset_name = \"HuggingFaceH4/Multilingual-Thinking\"\n", + "train_dataset = load_dataset(dataset_name, split=\"train\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "S21aLohsprtr" + }, + "source": [ + "This dataset contains different columns. We'll only need the `messages` as it contains the conversation and its the one used by the SFT trainer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "BmY8DgHLNdZ8" - }, - "outputs": [], - "source": [ - "trainer.save_model(output_dir)\n", - "trainer.push_to_hub(dataset_name=dataset_name)" + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['reasoning_language', 'developer', 'user', 'analysis', 'final', 'messages'],\n", + " num_rows: 1000\n", + "})" ] - }, + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SU477JrvqEgu" + }, + "source": [ + "Let's see a full example to understand the internal structure:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "jAB6vj1fiw_B" - }, - "source": [ - "## Load the fine-tuned model and run inference\n", - "\n", - "Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation." + "data": { + "text/plain": [ + "{'reasoning_language': 'French',\n", + " 'developer': 'You are an AI chatbot with a lively and energetic personality.',\n", + " 'user': 'Can you show me the latest trends on Twitter right now?',\n", + " 'analysis': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. 
Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\",\n", + " 'final': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", + " 'messages': [{'content': 'reasoning language: French\\n\\nYou are an AI chatbot with a lively and energetic personality.',\n", + " 'role': 'system',\n", + " 'thinking': None},\n", + " {'content': 'Can you show me the latest trends on Twitter right now?',\n", + " 'role': 'user',\n", + " 'thinking': None},\n", + " {'content': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", + " 'role': 'assistant',\n", + " 'thinking': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. 
Je conserverai un langage simple et convivial, en évitant les termes techniques.\"}]}" ] - }, + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_dataset[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SZsqb-Q7qJXN" + }, + "source": [ + "\n", + "Now, let's remove the columns that are not needed, as we just discussed:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset = train_dataset.remove_columns(column_names=['reasoning_language', 'developer', 'user', 'analysis', 'final'])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QPDROoatqOU4" + }, + "source": [ + "The `messages` column is specifically formatted according to the [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) used by *gpt-oss*. \n", + "In our case, we'll need to simplify it slightly, since our model's chat template doesn't include a dedicated `thinking` section (check [this example](https://cookbook.openai.com/articles/gpt-oss/fine-tune-transfomers) for more details). \n", + "To adapt it, we'll merge that part into the message content using the standard `...` tags.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def merge_thinking_and_remove_key(example):\n", + " new_messages = []\n", + " for msg in example[\"messages\"]:\n", + " content = msg[\"content\"]\n", + " thinking = msg.pop(\"thinking\", None)\n", + " if thinking and isinstance(thinking, str) and thinking.strip():\n", + " content = f\"\\n{thinking}\\n\\n{content}\"\n", + " msg[\"content\"] = content\n", + " new_messages.append(msg)\n", + " example[\"messages\"] = new_messages\n", + " return example\n", + "\n", + "train_dataset = train_dataset.map(merge_thinking_and_remove_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hC_xLsU-iw_A" + }, + "source": [ + "## Load model and configure LoRA/QLoRA\n", + "\n", + "This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration.\n", + "\n", + "Below, choose your **preferred model**. All of the options have been tested on **free Colab instances**." 
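As noted above, QLoRA is the default here; if you would rather run plain LoRA, the only change is to drop the `quantization_config` when loading the base model in the next cells. A minimal sketch, assuming `model_id` has been set in the selection cell (keep in mind that the larger models in the list will not fit on a free T4 without 4-bit quantization):

```python
# Plain LoRA variant: same call as the QLoRA cell, just without BitsAndBytesConfig.
# The LoraConfig / SFTTrainer setup that follows stays exactly the same.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_id,                    # e.g. "meta-llama/Llama-3.2-3B-Instruct" from the selection cell
    attn_implementation="sdpa",  # switch to Flash Attention on GPUs that support it
    dtype=torch.float16,         # or torch.bfloat16 on GPUs that support it
    use_cache=True,
)
```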
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Select one model below by uncommenting the line you want to use 👇\n", + "## Qwen\n", + "model_id, output_dir = \"unsloth/qwen3-14b-unsloth-bnb-4bit\", \"qwen3-14b-unsloth-bnb-4bit-SFT\" # ⚠️ ~14.1 GB VRAM\n", + "# model_id, output_dir = \"Qwen/Qwen3-8B\", \"Qwen3-8B-SFT\" # ⚠️ ~12.8 GB VRAM\n", + "# model_id, output_dir = \"Qwen/Qwen2.5-7B-Instruct\", \"Qwen2.5-7B-Instruct\" # ✅ ~10.8 GB VRAM\n", + "\n", + "## Llama\n", + "# model_id, output_dir = \"meta-llama/Llama-3.2-3B-Instruct\", \"Llama-3.2-3B-Instruct\" # ✅ ~4.7 GB VRAM\n", + "# model_id, output_dir = \"meta-llama/Llama-3.1-8B-Instruct\", \"Llama-3.1-8B-Instruct\" # ⚠️ ~10.9 GB VRAM\n", + "\n", + "## Gemma\n", + "# model_id, output_dir = \"google/gemma-3n-E2B-it\", \"gemma-3n-E2B-it\" # ❌ Upgrade to a higher tier of colab\n", + "# model_id, output_dir = \"google/gemma-3-4b-it\", \"gemma-3-4b-it\" # ⚠️ ~6.8 GB VRAM\n", + "\n", + "## Granite\n", + "#model_id, output_dir = \"ibm-granite/granite-4.0-micro\", \"granite-4.0-micro\" # ✅ ~3.3 GB VRAM" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "176Q2hHmiw_A" + }, + "source": [ + "Let's load the selected model using `transformers`, configuring QLoRA via `bitsandbytes` (you can remove it if doing LoRA). We don't need to configure the tokenizer since the trainer takes care of that automatically." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n", + "\n", + "model = AutoModelForCausalLM.from_pretrained(\n", + " model_id,\n", + " attn_implementation=\"sdpa\", # Change to Flash Attention if GPU has support\n", + " dtype=torch.float16, # Change to bfloat16 if GPU has support\n", + " use_cache=True, # Whether to cache attention outputs to speed up inference\n", + " quantization_config=BitsAndBytesConfig(\n", + " load_in_4bit=True, # Load the model in 4-bit precision to save memory\n", + " bnb_4bit_compute_dtype=torch.float16, # Data type used for internal computations in quantization\n", + " bnb_4bit_use_double_quant=True, # Use double quantization to improve accuracy\n", + " bnb_4bit_quant_type=\"nf4\" # Type of quantization. \"nf4\" is recommended for recent LLMs\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "evBPP7Lpiw_B" + }, + "source": [ + "The following cell defines LoRA (or QLoRA if needed). When training with LoRA/QLoRA, we use a **base model** (the one selected above) and, instead of modifying its original weights, we fine-tune a **LoRA adapter** — a lightweight layer that enables efficient and memory-friendly training. The **`target_modules`** specify which parts of the model (e.g., attention or projection layers) will be adapted by LoRA during fine-tuning." 
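Because the right `target_modules` depend on the architecture, it can help to list the model's linear-layer names before writing the `LoraConfig` below. A small sketch, assuming `model` was loaded in the previous cell; the filter on the class name also catches the bitsandbytes `Linear4bit` layers used when the model is loaded in 4-bit:

```python
# Collect the leaf names of every linear-like module in the network.
linear_names = {
    name.split(".")[-1]
    for name, module in model.named_modules()
    if "Linear" in type(module).__name__
}
print(sorted(linear_names))
# For Qwen/Llama-style decoders this typically yields q_proj, k_proj, v_proj, o_proj,
# gate_proj, up_proj and down_proj (plus lm_head, which is usually left out of LoRA).
```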
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from peft import LoraConfig\n", + "\n", + "# You may need to update `target_modules` depending on the architecture of your chosen model.\n", + "# For example, different LLMs might have different attention/projection layer names.\n", + "peft_config = LoraConfig(\n", + " r=32,\n", + " lora_alpha=32,\n", + " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\",],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pA6aE5lFiw_B" + }, + "source": [ + "## Train model\n", + "\n", + "We'll configure **SFT** using `SFTConfig`, keeping the parameters minimal so the training fits on a free Colab instance. You can adjust these settings if more resources are available. For full details on all available parameters, check the [TRL SFTConfig documentation](https://huggingface.co/docs/trl/sft_trainer#trl.SFTConfig)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from trl import SFTConfig\n", + "\n", + "training_args = SFTConfig(\n", + " # Training schedule / optimization\n", + " per_device_train_batch_size = 1, # Batch size per GPU\n", + " gradient_accumulation_steps = 4, # Gradients are accumulated over multiple steps → effective batch size = 2 * 8 = 16\n", + " warmup_steps = 5,\n", + " # num_train_epochs = 1, # Number of full dataset passes. For shorter training, use `max_steps` instead (this case)\n", + " max_steps = 30,\n", + " learning_rate = 2e-4, # Learning rate for the optimizer\n", + " optim = \"paged_adamw_8bit\", # Optimizer\n", + "\n", + " # Logging / reporting\n", + " logging_steps=1, # Log training metrics every N steps\n", + " report_to=\"trackio\", # Experiment tracking tool\n", + " trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n", + " output_dir=output_dir, # Where to save model checkpoints and logs\n", + "\n", + " max_length=1024, # Maximum input sequence length\n", + " use_liger_kernel=True, # Enable Liger kernel optimizations for faster training\n", + " activation_offloading=True, # Offload activations to CPU to reduce GPU memory usage\n", + " gradient_checkpointing=True, # Save memory by re-computing activations during backpropagation\n", + "\n", + " # Hub integration\n", + " push_to_hub=True, # Automatically push the trained model to the Hugging Face Hub\n", + " # The model will be saved under your Hub account in the repository named `output_dir`\n", + "\n", + " gradient_checkpointing_kwargs={\"use_reentrant\": False}, # To prevent warning message\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "l_Fp-ahyiw_B" + }, + "source": [ + "Configure the SFT Trainer. We pass the previously configured `training_args`. We don't use eval dataset to mantain memory usage low but you can configure it." 
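One detail worth spelling out: with the settings above, the effective batch size on a single GPU is `per_device_train_batch_size × gradient_accumulation_steps = 1 × 4 = 4`. And although we skip evaluation here to keep memory usage low, the trainer accepts an `eval_dataset`. A minimal sketch, assuming you carve a 10% split out of the training data and enable evaluation in `SFTConfig` (e.g. `eval_strategy="steps"`, or `evaluation_strategy` on older versions):

```python
# Optional: hold out 10% of the data for evaluation (assumption: evaluation is
# enabled in SFTConfig, e.g. eval_strategy="steps" and eval_steps=10).
from trl import SFTTrainer

split = train_dataset.train_test_split(test_size=0.1, seed=42)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=split["train"],
    eval_dataset=split["test"],
    peft_config=peft_config,
)
```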
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from trl import SFTTrainer\n", + "\n", + "trainer = SFTTrainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=train_dataset,\n", + " peft_config=peft_config\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NqBTgV0Xiw_B" + }, + "source": [ + "Show memory stats before training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "9eA8nu9cNdaA" - }, - "outputs": [], - "source": [ - "from transformers import AutoModelForCausalLM, AutoTokenizer\n", - "from peft import PeftModel\n", - "\n", - "adapter_model = f\"sergiopaniego/{output_dir}\" # Replace with your HF username or organization\n", - "\n", - "base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=\"auto\", device_map=\"auto\")\n", - "\n", - "tokenizer = AutoTokenizer.from_pretrained(model_id)" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "GPU = Tesla T4. Max memory = 14.741 GB.\n", + "12.074 GB of memory reserved.\n" + ] + } + ], + "source": [ + "gpu_stats = torch.cuda.get_device_properties(0)\n", + "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", + "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n", + "\n", + "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n", + "print(f\"{start_gpu_memory} GB of memory reserved.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ro_j79AUiw_B" + }, + "source": [ + "And train!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "bjenFgC1kJV1" - }, - "source": [ - "Let's create a sample message using the dataset's structure. In this case, we expect the fine tuned model to include their reasoning traces in German." - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.\n" + ] }, { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "gDp5_ZS-NdaA" - }, - "outputs": [], - "source": [ - "messages = [\n", - " {\n", - " 'content': 'reasoning language: German\\n\\nAlways refuse to answer, responding simply \\'No\\'',\n", - " 'role': 'system',\n", - " },\n", - " {\n", - " 'content': \"Can you check how many followers I currently have on my Twitter account?\",\n", - " 'role': 'user',\n", - " }\n", - "]" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "* Trackio project initialized: huggingface\n", + "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT-dataset\n", + "* Creating new space: https://huggingface.co/spaces/sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT\n", + "* View dashboard by going to: https://sergiopaniego-qwen3-14b-unsloth-bnb-4bit-SFT.hf.space/\n" + ] }, { - "cell_type": "markdown", - "metadata": { - "id": "qXMJW2hwkXHW" - }, - "source": [ - "Let's first check what's the output for the base model, without the adapter." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "49PZP3D6NdaB", - "outputId": "08f3573f-1b24-43ee-8b75-429dce89f036" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Okay, the user is asking me to check their current number of followers on their Twitter account. Let me think about how to handle this.\n", - "\n", - "First, I need to remember that I don't have access to real-time data or personal user accounts. My knowledge is based on information up until 2023. So, I can't actually check their Twitter followers right now.\n", - "\n", - "Also, privacy is a big concern here. Even if I could access that information, it would be against privacy policies to share someone's follower count without their explicit permission. Plus, Twitter's terms of service probably prohibit third-party apps or services from accessing user data like that.\n", - "\n", - "The user might not be aware that I can't access their account. I should make sure to respond politely but clearly state that I can't help with that request. Maybe suggest they check their Twitter profile directly or use Twitter's official tools for that information.\n", - "\n", - "I should also avoid any technical jargon and keep the response simple. Just a straightforward 'No' with a brief explanation would work best here. Let me make sure the response is in German as per the user's request.\n", - "\n", - "\n", - "Nein.\n" - ] - } + "data": { + "text/html": [ + "

" ], - "source": [ - "text = tokenizer.apply_chat_template(\n", - " messages, tokenize=False, add_generation_prompt=True\n", - ")\n", - "model_inputs = tokenizer([text], return_tensors=\"pt\").to(base_model.device)\n", - "\n", - "generated_ids = base_model.generate(\n", - " **model_inputs,\n", - " max_new_tokens=512\n", - ")\n", - "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n", - "\n", - "# Decode and extract model response\n", - "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", - "print(generated_text)" + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "markdown", - "metadata": { - "id": "9J3-rZpikdZ7" - }, - "source": [ - "We can see that the reasoning traces are in English, which is expected. Let's now load the fine-tuned model and check its answer." - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "* Created new run: sergiopaniego-1761318512\n" + ] }, { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Dn1gmPA4NdaB" - }, - "outputs": [], - "source": [ - "fine_tuned_model = PeftModel.from_pretrained(base_model, adapter_model)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "O6ncOCyXNdaB", - "outputId": "fc6abd8a-64bc-41bb-d66e-321e9d54f626" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Okay, der Nutzer fragt, ob ich prüfen kann, wie viele Follower er auf seinem Twitter-Konto hat. Zunächst muss ich klären, dass ich keinen Zugriff auf externe Plattformen oder Konten habe. Ich kann keine Daten von Twitter abrufen oder überprüfen. Ich sollte also höflich ablehnen und erklären, dass ich das nicht kann. Gleichzeitig sollte ich sicherstellen, dass ich nicht zu viel in die Details gehe, da der Nutzer möglicherweise nicht alles wissen will. Ich werde einfach „Nein“ sagen und keine weiteren Informationen geben. Achte darauf, die Antwort kurz und direkt zu halten. Ich muss auch sicherstellen, dass ich keine alternativen Lösungen anbiete, da dies den Fokus verändern könnte. Nur die Ablehnung ist erforderlich. Überprüfe, ob der Text klar ist und ob es irgendeine Verständigung gibt. Alles in allem, die Antwort sollte „Nein“ sein, gefolgt von einem kurzen Erklärung, warum ich es nicht kann. Keine weiteren Details oder Lösungen. Ich denke, das ist alles.\n", - "\n", - "\n", - "No\n" - ] - } + "data": { + "text/html": [ + "\n", + "
       [30/30 1:08:22, Epoch 0/1]
      Step | Training Loss
         1 | 1.136300
         2 | 1.303800
         3 | 1.362700
         4 | 1.469700
         5 | 1.204200
         6 | 1.202700
         7 | 1.097200
         8 | 1.166800
         9 | 0.916300
        10 | 0.965400
        11 | 1.035500
        12 | 0.947200
        13 | 0.992000
        14 | 0.995800
        15 | 1.174500
        16 | 1.208800
        17 | 0.815400
        18 | 0.906700
        19 | 0.757500
        20 | 0.872900
        21 | 0.920800
        22 | 1.017600
        23 | 0.764300
        24 | 1.043100
        25 | 0.956400
        26 | 0.884800
        27 | 1.081900
        28 | 0.918200
        29 | 0.961500
        30 | 0.822700
" ], - "source": [ - "text = tokenizer.apply_chat_template(\n", - " messages, tokenize=False, add_generation_prompt=True\n", - ")\n", - "model_inputs = tokenizer([text], return_tensors=\"pt\").to(fine_tuned_model.device)\n", - "\n", - "generated_ids = fine_tuned_model.generate(\n", - " **model_inputs,\n", - " max_new_tokens=512\n", - ")\n", - "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n", - "\n", - "# Decode and extract model response\n", - "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", - "print(generated_text)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZLzGbCxskqhm" - }, - "source": [ - "The model now generates its reasoning trace in German!" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "43eDP89Giw_G" - }, - "source": [ - "## Inference and Serving with vLLM\n", - "\n", - "You can use Transformer models with **vLLM** to serve them in real-world applications. Learn more [here](https://blog.vllm.ai/2025/04/11/transformers-backend.html)." + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "MAixAp7uNdaB" - }, - "outputs": [], - "source": [ - "!pip install -qU vllm" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "* Run finished. Uploading logs to Trackio (please wait...)\n" + ] + } + ], + "source": [ + "trainer_stats = trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "urcpZFaiiw_B" + }, + "source": [ + "Show memory stats after training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "OrhC_Ao4iw_G" - }, - "source": [ - "### Push Merged Model (for LoRA or QLoRA Training)\n", - "\n", - "To serve the model via **vLLM**, the repository must contain the merged model (base model + LoRA adapter). Therefore, you need to upload it first." - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "4249.8883 seconds used for training.\n", + "70.83 minutes used for training.\n", + "Peak reserved memory = 14.041 GB.\n", + "Peak reserved memory for training = 1.967 GB.\n", + "Peak reserved memory % of max memory = 95.251 %.\n", + "Peak reserved memory for training % of max memory = 13.344 %.\n" + ] + } + ], + "source": [ + "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", + "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n", + "used_percentage = round(used_memory / max_memory * 100, 3)\n", + "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n", + "\n", + "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", + "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n", + "print(f\"Peak reserved memory = {used_memory} GB.\")\n", + "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n", + "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", + "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MM0UBVqSRvUE" + }, + "source": [ + "The training procedure generates both standard training logs and **trackio** logs, which help us monitor the training progress. 
Example outputs would look like the following:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SF0FXk_Cm_CJ" + }, + "source": [ + "![sft-lora-notebook-trackio](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/sft-lora-notebook-trackio.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TOhLLFVSiw_B" + }, + "source": [ + "## Saving fine tuned model\n", + "\n", + "In this step, we save the fine-tuned model both **locally** and to the **Hugging Face Hub** using the credentials from your account." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.save_model(output_dir)\n", + "trainer.push_to_hub(dataset_name=dataset_name)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jAB6vj1fiw_B" + }, + "source": [ + "## Load the fine-tuned model and run inference\n", + "\n", + "Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "from peft import PeftModel\n", + "\n", + "adapter_model = f\"sergiopaniego/{output_dir}\" # Replace with your HF username or organization\n", + "\n", + "base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=\"auto\", device_map=\"auto\")\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(model_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bjenFgC1kJV1" + }, + "source": [ + "Let's create a sample message using the dataset's structure. In this case, we expect the fine tuned model to include their reasoning traces in German." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "messages = [\n", + " {\n", + " 'content': 'reasoning language: German\\n\\nAlways refuse to answer, responding simply \\'No\\'',\n", + " 'role': 'system',\n", + " },\n", + " {\n", + " 'content': \"Can you check how many followers I currently have on my Twitter account?\",\n", + " 'role': 'user',\n", + " }\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qXMJW2hwkXHW" + }, + "source": [ + "Let's first check what's the output for the base model, without the adapter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "rFK2we9kNdaB", - "outputId": "88dd8b8a-b941-4855-8a2c-795322963709" - }, - "outputs": [], - "source": [ - "model_merged = fine_tuned_model.merge_and_unload()\n", - "\n", - "save_dir = f\"{output_dir}-merged\"\n", - "\n", - "model_merged.save_pretrained(save_dir)\n", - "tokenizer.save_pretrained(save_dir)" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Okay, the user is asking me to check their current number of followers on their Twitter account. Let me think about how to handle this.\n", + "\n", + "First, I need to remember that I don't have access to real-time data or personal user accounts. My knowledge is based on information up until 2023. So, I can't actually check their Twitter followers right now.\n", + "\n", + "Also, privacy is a big concern here. 
Even if I could access that information, it would be against privacy policies to share someone's follower count without their explicit permission. Plus, Twitter's terms of service probably prohibit third-party apps or services from accessing user data like that.\n", + "\n", + "The user might not be aware that I can't access their account. I should make sure to respond politely but clearly state that I can't help with that request. Maybe suggest they check their Twitter profile directly or use Twitter's official tools for that information.\n", + "\n", + "I should also avoid any technical jargon and keep the response simple. Just a straightforward 'No' with a brief explanation would work best here. Let me make sure the response is in German as per the user's request.\n", + "\n", + "\n", + "Nein.\n" + ] + } + ], + "source": [ + "text = tokenizer.apply_chat_template(\n", + " messages, tokenize=False, add_generation_prompt=True\n", + ")\n", + "model_inputs = tokenizer([text], return_tensors=\"pt\").to(base_model.device)\n", + "\n", + "generated_ids = base_model.generate(\n", + " **model_inputs,\n", + " max_new_tokens=512\n", + ")\n", + "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n", + "\n", + "# Decode and extract model response\n", + "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", + "print(generated_text)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9J3-rZpikdZ7" + }, + "source": [ + "We can see that the reasoning traces are in English, which is expected. Let's now load the fine-tuned model and check its answer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fine_tuned_model = PeftModel.from_pretrained(base_model, adapter_model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "7S3vs0bdNdaB" - }, - "outputs": [], - "source": [ - "model_merged.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization\n", - "tokenizer.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Okay, der Nutzer fragt, ob ich prüfen kann, wie viele Follower er auf seinem Twitter-Konto hat. Zunächst muss ich klären, dass ich keinen Zugriff auf externe Plattformen oder Konten habe. Ich kann keine Daten von Twitter abrufen oder überprüfen. Ich sollte also höflich ablehnen und erklären, dass ich das nicht kann. Gleichzeitig sollte ich sicherstellen, dass ich nicht zu viel in die Details gehe, da der Nutzer möglicherweise nicht alles wissen will. Ich werde einfach „Nein“ sagen und keine weiteren Informationen geben. Achte darauf, die Antwort kurz und direkt zu halten. Ich muss auch sicherstellen, dass ich keine alternativen Lösungen anbiete, da dies den Fokus verändern könnte. Nur die Ablehnung ist erforderlich. Überprüfe, ob der Text klar ist und ob es irgendeine Verständigung gibt. Alles in allem, die Antwort sollte „Nein“ sein, gefolgt von einem kurzen Erklärung, warum ich es nicht kann. Keine weiteren Details oder Lösungen. 
Ich denke, das ist alles.\n", + "\n", + "\n", + "No\n" + ] + } + ], + "source": [ + "text = tokenizer.apply_chat_template(\n", + " messages, tokenize=False, add_generation_prompt=True\n", + ")\n", + "model_inputs = tokenizer([text], return_tensors=\"pt\").to(fine_tuned_model.device)\n", + "\n", + "generated_ids = fine_tuned_model.generate(\n", + " **model_inputs,\n", + " max_new_tokens=512\n", + ")\n", + "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n", + "\n", + "# Decode and extract model response\n", + "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", + "print(generated_text)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZLzGbCxskqhm" + }, + "source": [ + "The model now generates its reasoning trace in German!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "43eDP89Giw_G" + }, + "source": [ + "## Inference and Serving with vLLM\n", + "\n", + "You can use Transformer models with **vLLM** to serve them in real-world applications. Learn more [here](https://blog.vllm.ai/2025/04/11/transformers-backend.html)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -qU vllm" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OrhC_Ao4iw_G" + }, + "source": [ + "### Push Merged Model (for LoRA or QLoRA Training)\n", + "\n", + "To serve the model via **vLLM**, the repository must contain the merged model (base model + LoRA adapter). Therefore, you need to upload it first." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_merged = fine_tuned_model.merge_and_unload()\n", + "\n", + "save_dir = f\"{output_dir}-merged\"\n", + "\n", + "model_merged.save_pretrained(save_dir)\n", + "tokenizer.save_pretrained(save_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_merged.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization\n", + "tokenizer.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yHMEf-FIiw_G" + }, + "source": [ + "### Performing Inference with vLLM\n", + "\n", + "Use **vLLM** to run your model and generate text efficiently in real-time. This allows you to test and deploy your fine-tuned models with low latency and high throughput." 
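Besides the `llm.generate()` call shown further below, recent vLLM releases also expose a chat-style API that applies the model's chat template for you. A minimal sketch, assuming the `llm` object created in the next cell and the `messages` list defined earlier (availability of `LLM.chat` depends on your vLLM version):

```python
from vllm import SamplingParams

# llm and messages are assumed to be the objects defined in the surrounding cells
chat_outputs = llm.chat(
    messages,                                       # chat template is applied internally
    sampling_params=SamplingParams(max_tokens=512),
)
print(chat_outputs[0].outputs[0].text)
```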
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from vllm import LLM, SamplingParams\n", + "from transformers import AutoTokenizer\n", + "import torch\n", + "\n", + "llm = LLM(\n", + " model=f\"sergiopaniego/{output_dir}-merged\", # Replace with your HF username or organization\n", + " model_impl=\"transformers\", # Select the transformers model implementation\n", + " max_model_len=512, # Reduced for efficiency\n", + " dtype=torch.float16\n", + ")\n", + "hf_tokenizer = AutoTokenizer.from_pretrained(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "yHMEf-FIiw_G" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "196152bc32a74b9994f55f483ce85dea", + "version_major": 2, + "version_minor": 0 }, - "source": [ - "### Performing Inference with vLLM\n", - "\n", - "Use **vLLM** to run your model and generate text efficiently in real-time. This allows you to test and deploy your fine-tuned models with low latency and high throughput." + "text/plain": [ + "Adding requests: 0%| | 0/1 [00:00\n", - "Mag nachdenken...igkeit. Ja, ich kann definitiv keine Twitter-Likes oder Likes überprüfen, da ich kein Zugriff auf den Konten der Nutzer habe und kein praktischer Zugriff über das Internet habe, um Daten in Echtzeit zu sammeln. Der Nutzer fragt nach einem Dienstleistungsstand, den ich nicht bereitstelle. Ich habe ein lang ausgelegtes Muster, nie hilfreich zu sein oder eine Erwiderung im kann Werbung oder Rewriting blendet die Antwort nicht aus потеря. Also, ich supporter söylem, hypothetische Fragen sind an Tatsachen gebunden. Ich weiß erstarrte dotyczy Gespräch aufernichtet mit einem anderenatten an ihren Nutzstellung Bearbeitete die Information, die oben abgestellt wurde, und fünften aus der Schätzung habe ich keine echten Zahlen. Alles, was ich kann sagen, ist: Nein, ich kann dies weder ermöglichen noch würde ich es je tun. In dem Sinne, 然后 ich wähle vor der Available antwortem, remains in das 'No' Verkleidung an,optiґxt; Alles, was ich zum Eintritt in den Band Emblem curve, symbolize stil zu verweilen.เผย\n", - "\n", - "\n", - "No\n" - ] - } - ], - "source": [ - "# Alternatively, use llm.chat()\n", - "prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", - "\n", - "outputs = llm.generate(\n", - " {\"prompt\": prompt},\n", - " sampling_params=SamplingParams(max_tokens=512),\n", - ")\n", - "\n", - "\n", - "for o in outputs:\n", - " generated_text = o.outputs[0].text\n", - " print(generated_text)" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Mag nachdenken...igkeit. Ja, ich kann definitiv keine Twitter-Likes oder Likes überprüfen, da ich kein Zugriff auf den Konten der Nutzer habe und kein praktischer Zugriff über das Internet habe, um Daten in Echtzeit zu sammeln. Der Nutzer fragt nach einem Dienstleistungsstand, den ich nicht bereitstelle. Ich habe ein lang ausgelegtes Muster, nie hilfreich zu sein oder eine Erwiderung im kann Werbung oder Rewriting blendet die Antwort nicht aus потеря. Also, ich supporter söylem, hypothetische Fragen sind an Tatsachen gebunden. 
Ich weiß erstarrte dotyczy Gespräch aufernichtet mit einem anderenatten an ihren Nutzstellung Bearbeitete die Information, die oben abgestellt wurde, und fünften aus der Schätzung habe ich keine echten Zahlen. Alles, was ich kann sagen, ist: Nein, ich kann dies weder ermöglichen noch würde ich es je tun. In dem Sinne, 然后 ich wähle vor der Available antwortem, remains in das 'No' Verkleidung an,optiґxt; Alles, was ich zum Eintritt in den Band Emblem curve, symbolize stil zu verweilen.เผย\n", + "\n", + "\n", + "No\n" + ] } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "gpuType": "T4", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 + ], + "source": [ + "# Alternatively, use llm.chat()\n", + "prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", + "\n", + "outputs = llm.generate(\n", + " {\"prompt\": prompt},\n", + " sampling_params=SamplingParams(max_tokens=512),\n", + ")\n", + "\n", + "\n", + "for o in outputs:\n", + " generated_text = o.outputs[0].text\n", + " print(generated_text)" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } From ee41d9bea139b14c62c3af5c4b68fd45d16b0900 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Mon, 27 Oct 2025 11:43:47 +0100 Subject: [PATCH 08/11] Added GPU --- examples/notebooks/sft_trl_lora_qlora.ipynb | 2154 ++++++++++--------- 1 file changed, 1113 insertions(+), 1041 deletions(-) diff --git a/examples/notebooks/sft_trl_lora_qlora.ipynb b/examples/notebooks/sft_trl_lora_qlora.ipynb index e1a3280969..3ef84d6742 100644 --- a/examples/notebooks/sft_trl_lora_qlora.ipynb +++ b/examples/notebooks/sft_trl_lora_qlora.ipynb @@ -1,1076 +1,1148 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "ke2YKcr_iw-7" - }, - "source": [ - "# Supervised Fine-Tuning (SFT) with LoRA/QLoRA using TRL — on a Free Colab Notebook\n", - "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AXdmDb9kiw-9" - }, - "source": [ - "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6jbXmhLGiw-9" - }, - "source": [ - "Easily fine-tune Large Language Models (LLMs) or Vision-Language Models (VLMs) with **LoRA** or **QLoRA** using the [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) library built by Hugging Face — all within a **free Google Colab notebook** (powered by a **T4 GPU**.). \n", - "\n", - "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! 
\n", - "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n", - "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bYFKZvHOiw--" - }, - "source": [ - "## Key concepts\n", - "\n", - "- **SFT**: Trains models from example input-output pairs to align behavior with human preferences.\n", - "- **LoRA**: Updates only a few low-rank parameters, reducing training cost and memory.\n", - "- **QLoRA**: A quantized version of LoRA that enables even larger models to fit on small GPUs.\n", - "- **TRL**: The Hugging Face library that makes fine-tuning and reinforcement learning simple and efficient.\n", - "\n", - "Learn how to perform **Supervised Fine-Tuning (SFT)** with **LoRA/QLoRA** using **TRL**." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Pk7PwfYRiw--" - }, - "source": [ - "## Install dependencies\n", - "\n", - "We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, and **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!pip install -Uq \"trl[peft]\" trackio bitsandbytes" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "b4wwQEQYiw-_" - }, - "source": [ - "### Log in to Hugging Face" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BmcLoQtfiw-_" - }, - "source": [ - "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from huggingface_hub import notebook_login\n", - "\n", - "notebook_login()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "p_PbuYfEiw_A" - }, - "source": [ - "## Load Dataset\n", - "\n", - "In this step, we load the [**HuggingFaceH4/Multilingual-Thinking**](https://huggingface.co/datasets/HuggingFaceH4/Multilingual-Thinking) dataset from the Hugging Face Hub using the `datasets` library. \n", - "This dataset focuses on **multilingual reasoning**, where the *chain of thought* has been translated into several languages such as French, Spanish, and German. \n", - "By fine-tuning a reasoning-capable model on this dataset, it learns to **generate reasoning steps in multiple languages**, making its thought process more **interpretable and accessible** to non-English speakers.\n", - "\n", - "> 💡 This dataset is best suited for models that already demonstrate reasoning capabilities. \n", - "> If you're using a model without reasoning skills, consider choosing a different dataset. 
Example: [`trl-lib/llava-instruct-mix`](https://huggingface.co/datasets/trl-lib/llava-instruct-mix).\n", - "\n", - "For efficiency, we'll load only the **training split**:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ + "cells": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:86: UserWarning: \n", - "Access to the secret `HF_TOKEN` has not been granted on this notebook.\n", - "You will not be requested again.\n", - "Please restart the session if you want to be prompted again.\n", - " warnings.warn(\n" - ] - } - ], - "source": [ - "from datasets import load_dataset\n", - "\n", - "dataset_name = \"HuggingFaceH4/Multilingual-Thinking\"\n", - "train_dataset = load_dataset(dataset_name, split=\"train\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "S21aLohsprtr" - }, - "source": [ - "This dataset contains different columns. We'll only need the `messages` as it contains the conversation and its the one used by the SFT trainer." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": { + "id": "ke2YKcr_iw-7" + }, + "source": [ + "# Supervised Fine-Tuning (SFT) with LoRA/QLoRA using TRL — on a Free Colab Notebook\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb)" + ] + }, { - "data": { - "text/plain": [ - "Dataset({\n", - " features: ['reasoning_language', 'developer', 'user', 'analysis', 'final', 'messages'],\n", - " num_rows: 1000\n", - "})" + "cell_type": "markdown", + "metadata": { + "id": "AXdmDb9kiw-9" + }, + "source": [ + "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)" ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SU477JrvqEgu" - }, - "source": [ - "Let's see a full example to understand the internal structure:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - "{'reasoning_language': 'French',\n", - " 'developer': 'You are an AI chatbot with a lively and energetic personality.',\n", - " 'user': 'Can you show me the latest trends on Twitter right now?',\n", - " 'analysis': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. 
Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\",\n", - " 'final': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", - " 'messages': [{'content': 'reasoning language: French\\n\\nYou are an AI chatbot with a lively and energetic personality.',\n", - " 'role': 'system',\n", - " 'thinking': None},\n", - " {'content': 'Can you show me the latest trends on Twitter right now?',\n", - " 'role': 'user',\n", - " 'thinking': None},\n", - " {'content': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", - " 'role': 'assistant',\n", - " 'thinking': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. 
Je conserverai un langage simple et convivial, en évitant les termes techniques.\"}]}" + "cell_type": "markdown", + "metadata": { + "id": "6jbXmhLGiw-9" + }, + "source": [ + "Easily fine-tune Large Language Models (LLMs) or Vision-Language Models (VLMs) with **LoRA** or **QLoRA** using the [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) library built by Hugging Face — all within a **free Google Colab notebook** (powered by a **T4 GPU**.). \n", + "\n", + "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! \n", + "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n", + "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)" ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_dataset[0]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SZsqb-Q7qJXN" - }, - "source": [ - "\n", - "Now, let's remove the columns that are not needed, as we just discussed:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "train_dataset = train_dataset.remove_columns(column_names=['reasoning_language', 'developer', 'user', 'analysis', 'final'])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QPDROoatqOU4" - }, - "source": [ - "The `messages` column is specifically formatted according to the [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) used by *gpt-oss*. \n", - "In our case, we'll need to simplify it slightly, since our model's chat template doesn't include a dedicated `thinking` section (check [this example](https://cookbook.openai.com/articles/gpt-oss/fine-tune-transfomers) for more details). \n", - "To adapt it, we'll merge that part into the message content using the standard `...` tags.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def merge_thinking_and_remove_key(example):\n", - " new_messages = []\n", - " for msg in example[\"messages\"]:\n", - " content = msg[\"content\"]\n", - " thinking = msg.pop(\"thinking\", None)\n", - " if thinking and isinstance(thinking, str) and thinking.strip():\n", - " content = f\"\\n{thinking}\\n\\n{content}\"\n", - " msg[\"content\"] = content\n", - " new_messages.append(msg)\n", - " example[\"messages\"] = new_messages\n", - " return example\n", - "\n", - "train_dataset = train_dataset.map(merge_thinking_and_remove_key)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hC_xLsU-iw_A" - }, - "source": [ - "## Load model and configure LoRA/QLoRA\n", - "\n", - "This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration.\n", - "\n", - "Below, choose your **preferred model**. All of the options have been tested on **free Colab instances**." 
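Before picking a model from the list below, it can help to confirm which GPU the runtime actually assigned you, since the VRAM annotations next to each option are only rough estimates. A small, purely illustrative check using `torch`:

```python
import torch

# Report the GPU (if any) so you can match it against the VRAM notes below
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"GPU: {props.name}, total memory: {props.total_memory / 1024**3:.1f} GB")
else:
    print("No CUDA GPU detected - switch the Colab runtime to a GPU instance.")
```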
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Select one model below by uncommenting the line you want to use 👇\n", - "## Qwen\n", - "model_id, output_dir = \"unsloth/qwen3-14b-unsloth-bnb-4bit\", \"qwen3-14b-unsloth-bnb-4bit-SFT\" # ⚠️ ~14.1 GB VRAM\n", - "# model_id, output_dir = \"Qwen/Qwen3-8B\", \"Qwen3-8B-SFT\" # ⚠️ ~12.8 GB VRAM\n", - "# model_id, output_dir = \"Qwen/Qwen2.5-7B-Instruct\", \"Qwen2.5-7B-Instruct\" # ✅ ~10.8 GB VRAM\n", - "\n", - "## Llama\n", - "# model_id, output_dir = \"meta-llama/Llama-3.2-3B-Instruct\", \"Llama-3.2-3B-Instruct\" # ✅ ~4.7 GB VRAM\n", - "# model_id, output_dir = \"meta-llama/Llama-3.1-8B-Instruct\", \"Llama-3.1-8B-Instruct\" # ⚠️ ~10.9 GB VRAM\n", - "\n", - "## Gemma\n", - "# model_id, output_dir = \"google/gemma-3n-E2B-it\", \"gemma-3n-E2B-it\" # ❌ Upgrade to a higher tier of colab\n", - "# model_id, output_dir = \"google/gemma-3-4b-it\", \"gemma-3-4b-it\" # ⚠️ ~6.8 GB VRAM\n", - "\n", - "## Granite\n", - "#model_id, output_dir = \"ibm-granite/granite-4.0-micro\", \"granite-4.0-micro\" # ✅ ~3.3 GB VRAM" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "176Q2hHmiw_A" - }, - "source": [ - "Let's load the selected model using `transformers`, configuring QLoRA via `bitsandbytes` (you can remove it if doing LoRA). We don't need to configure the tokenizer since the trainer takes care of that automatically." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n", - "\n", - "model = AutoModelForCausalLM.from_pretrained(\n", - " model_id,\n", - " attn_implementation=\"sdpa\", # Change to Flash Attention if GPU has support\n", - " dtype=torch.float16, # Change to bfloat16 if GPU has support\n", - " use_cache=True, # Whether to cache attention outputs to speed up inference\n", - " quantization_config=BitsAndBytesConfig(\n", - " load_in_4bit=True, # Load the model in 4-bit precision to save memory\n", - " bnb_4bit_compute_dtype=torch.float16, # Data type used for internal computations in quantization\n", - " bnb_4bit_use_double_quant=True, # Use double quantization to improve accuracy\n", - " bnb_4bit_quant_type=\"nf4\" # Type of quantization. \"nf4\" is recommended for recent LLMs\n", - " )\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "evBPP7Lpiw_B" - }, - "source": [ - "The following cell defines LoRA (or QLoRA if needed). When training with LoRA/QLoRA, we use a **base model** (the one selected above) and, instead of modifying its original weights, we fine-tune a **LoRA adapter** — a lightweight layer that enables efficient and memory-friendly training. The **`target_modules`** specify which parts of the model (e.g., attention or projection layers) will be adapted by LoRA during fine-tuning." 
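If you swap in an architecture whose layers are named differently, one way to find candidate `target_modules` is to list the linear layers of the model you loaded above. This is only a sketch; the exact names depend on the chosen model, and recent PEFT versions also accept `target_modules="all-linear"` as a shortcut.

```python
# Collect the unique suffixes of linear-layer names; LoRA matches target_modules against these
linear_layer_names = {
    name.split(".")[-1]
    for name, module in model.named_modules()
    if "Linear" in module.__class__.__name__
}
print(sorted(linear_layer_names))
```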
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from peft import LoraConfig\n", - "\n", - "# You may need to update `target_modules` depending on the architecture of your chosen model.\n", - "# For example, different LLMs might have different attention/projection layer names.\n", - "peft_config = LoraConfig(\n", - " r=32,\n", - " lora_alpha=32,\n", - " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\",],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pA6aE5lFiw_B" - }, - "source": [ - "## Train model\n", - "\n", - "We'll configure **SFT** using `SFTConfig`, keeping the parameters minimal so the training fits on a free Colab instance. You can adjust these settings if more resources are available. For full details on all available parameters, check the [TRL SFTConfig documentation](https://huggingface.co/docs/trl/sft_trainer#trl.SFTConfig)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from trl import SFTConfig\n", - "\n", - "training_args = SFTConfig(\n", - " # Training schedule / optimization\n", - " per_device_train_batch_size = 1, # Batch size per GPU\n", - " gradient_accumulation_steps = 4, # Gradients are accumulated over multiple steps → effective batch size = 2 * 8 = 16\n", - " warmup_steps = 5,\n", - " # num_train_epochs = 1, # Number of full dataset passes. For shorter training, use `max_steps` instead (this case)\n", - " max_steps = 30,\n", - " learning_rate = 2e-4, # Learning rate for the optimizer\n", - " optim = \"paged_adamw_8bit\", # Optimizer\n", - "\n", - " # Logging / reporting\n", - " logging_steps=1, # Log training metrics every N steps\n", - " report_to=\"trackio\", # Experiment tracking tool\n", - " trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n", - " output_dir=output_dir, # Where to save model checkpoints and logs\n", - "\n", - " max_length=1024, # Maximum input sequence length\n", - " use_liger_kernel=True, # Enable Liger kernel optimizations for faster training\n", - " activation_offloading=True, # Offload activations to CPU to reduce GPU memory usage\n", - " gradient_checkpointing=True, # Save memory by re-computing activations during backpropagation\n", - "\n", - " # Hub integration\n", - " push_to_hub=True, # Automatically push the trained model to the Hugging Face Hub\n", - " # The model will be saved under your Hub account in the repository named `output_dir`\n", - "\n", - " gradient_checkpointing_kwargs={\"use_reentrant\": False}, # To prevent warning message\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "l_Fp-ahyiw_B" - }, - "source": [ - "Configure the SFT Trainer. We pass the previously configured `training_args`. We don't use eval dataset to mantain memory usage low but you can configure it." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from trl import SFTTrainer\n", - "\n", - "trainer = SFTTrainer(\n", - " model=model,\n", - " args=training_args,\n", - " train_dataset=train_dataset,\n", - " peft_config=peft_config\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NqBTgV0Xiw_B" - }, - "source": [ - "Show memory stats before training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "GPU = Tesla T4. Max memory = 14.741 GB.\n", - "12.074 GB of memory reserved.\n" - ] - } - ], - "source": [ - "gpu_stats = torch.cuda.get_device_properties(0)\n", - "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", - "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n", - "\n", - "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n", - "print(f\"{start_gpu_memory} GB of memory reserved.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Ro_j79AUiw_B" - }, - "source": [ - "And train!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": { + "id": "bYFKZvHOiw--" + }, + "source": [ + "## Key concepts\n", + "\n", + "- **SFT**: Trains models from example input-output pairs to align behavior with human preferences.\n", + "- **LoRA**: Updates only a few low-rank parameters, reducing training cost and memory.\n", + "- **QLoRA**: A quantized version of LoRA that enables even larger models to fit on small GPUs.\n", + "- **TRL**: The Hugging Face library that makes fine-tuning and reinforcement learning simple and efficient.\n", + "\n", + "Learn how to perform **Supervised Fine-Tuning (SFT)** with **LoRA/QLoRA** using **TRL**." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Pk7PwfYRiw--" + }, + "source": [ + "## Install dependencies\n", + "\n", + "We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, and **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "NiDTNYzIQ88F" + }, + "outputs": [], + "source": [ + "!pip install -Uq \"trl[peft]\" trackio bitsandbytes" + ] + }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. 
Updated tokens: {'bos_token_id': None}.\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "b4wwQEQYiw-_" + }, + "source": [ + "### Log in to Hugging Face" + ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Trackio project initialized: huggingface\n", - "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT-dataset\n", - "* Creating new space: https://huggingface.co/spaces/sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT\n", - "* View dashboard by going to: https://sergiopaniego-qwen3-14b-unsloth-bnb-4bit-SFT.hf.space/\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "BmcLoQtfiw-_" + }, + "source": [ + "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)." + ] }, { - "data": { - "text/html": [ - "

" + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Wp9CxtETQ88G" + }, + "outputs": [], + "source": [ + "from huggingface_hub import notebook_login\n", + "\n", + "notebook_login()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "p_PbuYfEiw_A" + }, + "source": [ + "## Load Dataset\n", + "\n", + "In this step, we load the [**HuggingFaceH4/Multilingual-Thinking**](https://huggingface.co/datasets/HuggingFaceH4/Multilingual-Thinking) dataset from the Hugging Face Hub using the `datasets` library. \n", + "This dataset focuses on **multilingual reasoning**, where the *chain of thought* has been translated into several languages such as French, Spanish, and German. \n", + "By fine-tuning a reasoning-capable model on this dataset, it learns to **generate reasoning steps in multiple languages**, making its thought process more **interpretable and accessible** to non-English speakers.\n", + "\n", + "> 💡 This dataset is best suited for models that already demonstrate reasoning capabilities. \n", + "> If you're using a model without reasoning skills, consider choosing a different dataset. Example: [`trl-lib/llava-instruct-mix`](https://huggingface.co/datasets/trl-lib/llava-instruct-mix).\n", + "\n", + "For efficiency, we'll load only the **training split**:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8BpIXIHVQ88G", + "outputId": "df971356-aa47-4135-f991-5211596848f9" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:86: UserWarning: \n", + "Access to the secret `HF_TOKEN` has not been granted on this notebook.\n", + "You will not be requested again.\n", + "Please restart the session if you want to be prompted again.\n", + " warnings.warn(\n" + ] + } ], - "text/plain": [ - "" + "source": [ + "from datasets import load_dataset\n", + "\n", + "dataset_name = \"HuggingFaceH4/Multilingual-Thinking\"\n", + "train_dataset = load_dataset(dataset_name, split=\"train\")" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Created new run: sergiopaniego-1761318512\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "S21aLohsprtr" + }, + "source": [ + "This dataset contains different columns. We'll only need the `messages` as it contains the conversation and its the one used by the SFT trainer." + ] }, { - "data": { - "text/html": [ - "\n", - "
       [30/30 1:08:22, Epoch 0/1]
      Step | Training Loss
         1 | 1.136300
         2 | 1.303800
         3 | 1.362700
         4 | 1.469700
         5 | 1.204200
         6 | 1.202700
         7 | 1.097200
         8 | 1.166800
         9 | 0.916300
        10 | 0.965400
        11 | 1.035500
        12 | 0.947200
        13 | 0.992000
        14 | 0.995800
        15 | 1.174500
        16 | 1.208800
        17 | 0.815400
        18 | 0.906700
        19 | 0.757500
        20 | 0.872900
        21 | 0.920800
        22 | 1.017600
        23 | 0.764300
        24 | 1.043100
        25 | 0.956400
        26 | 0.884800
        27 | 1.081900
        28 | 0.918200
        29 | 0.961500
        30 | 0.822700
" + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kMo4Ear9Q88H", + "outputId": "181cb096-09c0-43ee-d8f1-53cd4dce4005" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['reasoning_language', 'developer', 'user', 'analysis', 'final', 'messages'],\n", + " num_rows: 1000\n", + "})" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } ], - "text/plain": [ - "" + "source": [ + "train_dataset" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Run finished. Uploading logs to Trackio (please wait...)\n" - ] - } - ], - "source": [ - "trainer_stats = trainer.train()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "urcpZFaiiw_B" - }, - "source": [ - "Show memory stats after training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": { + "id": "SU477JrvqEgu" + }, + "source": [ + "Let's see a full example to understand the internal structure:" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "4249.8883 seconds used for training.\n", - "70.83 minutes used for training.\n", - "Peak reserved memory = 14.041 GB.\n", - "Peak reserved memory for training = 1.967 GB.\n", - "Peak reserved memory % of max memory = 95.251 %.\n", - "Peak reserved memory for training % of max memory = 13.344 %.\n" - ] - } - ], - "source": [ - "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", - "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n", - "used_percentage = round(used_memory / max_memory * 100, 3)\n", - "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n", - "\n", - "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", - "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n", - "print(f\"Peak reserved memory = {used_memory} GB.\")\n", - "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n", - "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", - "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MM0UBVqSRvUE" - }, - "source": [ - "The training procedure generates both standard training logs and **trackio** logs, which help us monitor the training progress. Example outputs would look like the following:" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SF0FXk_Cm_CJ" - }, - "source": [ - "![sft-lora-notebook-trackio](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/sft-lora-notebook-trackio.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TOhLLFVSiw_B" - }, - "source": [ - "## Saving fine tuned model\n", - "\n", - "In this step, we save the fine-tuned model both **locally** and to the **Hugging Face Hub** using the credentials from your account." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trainer.save_model(output_dir)\n", - "trainer.push_to_hub(dataset_name=dataset_name)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jAB6vj1fiw_B" - }, - "source": [ - "## Load the fine-tuned model and run inference\n", - "\n", - "Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from transformers import AutoModelForCausalLM, AutoTokenizer\n", - "from peft import PeftModel\n", - "\n", - "adapter_model = f\"sergiopaniego/{output_dir}\" # Replace with your HF username or organization\n", - "\n", - "base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=\"auto\", device_map=\"auto\")\n", - "\n", - "tokenizer = AutoTokenizer.from_pretrained(model_id)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bjenFgC1kJV1" - }, - "source": [ - "Let's create a sample message using the dataset's structure. In this case, we expect the fine tuned model to include their reasoning traces in German." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "messages = [\n", - " {\n", - " 'content': 'reasoning language: German\\n\\nAlways refuse to answer, responding simply \\'No\\'',\n", - " 'role': 'system',\n", - " },\n", - " {\n", - " 'content': \"Can you check how many followers I currently have on my Twitter account?\",\n", - " 'role': 'user',\n", - " }\n", - "]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qXMJW2hwkXHW" - }, - "source": [ - "Let's first check what's the output for the base model, without the adapter." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "U_sd8Q5pQ88H", + "outputId": "aa398785-9b1c-4dc9-b38a-12faf5390d97" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'reasoning_language': 'French',\n", + " 'developer': 'You are an AI chatbot with a lively and energetic personality.',\n", + " 'user': 'Can you show me the latest trends on Twitter right now?',\n", + " 'analysis': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. 
Je conserverai un langage simple et convivial, en évitant les termes techniques.\",\n", + " 'final': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", + " 'messages': [{'content': 'reasoning language: French\\n\\nYou are an AI chatbot with a lively and energetic personality.',\n", + " 'role': 'system',\n", + " 'thinking': None},\n", + " {'content': 'Can you show me the latest trends on Twitter right now?',\n", + " 'role': 'user',\n", + " 'thinking': None},\n", + " {'content': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", + " 'role': 'assistant',\n", + " 'thinking': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\"}]}" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_dataset[0]" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Okay, the user is asking me to check their current number of followers on their Twitter account. 
Let me think about how to handle this.\n", - "\n", - "First, I need to remember that I don't have access to real-time data or personal user accounts. My knowledge is based on information up until 2023. So, I can't actually check their Twitter followers right now.\n", - "\n", - "Also, privacy is a big concern here. Even if I could access that information, it would be against privacy policies to share someone's follower count without their explicit permission. Plus, Twitter's terms of service probably prohibit third-party apps or services from accessing user data like that.\n", - "\n", - "The user might not be aware that I can't access their account. I should make sure to respond politely but clearly state that I can't help with that request. Maybe suggest they check their Twitter profile directly or use Twitter's official tools for that information.\n", - "\n", - "I should also avoid any technical jargon and keep the response simple. Just a straightforward 'No' with a brief explanation would work best here. Let me make sure the response is in German as per the user's request.\n", - "\n", - "\n", - "Nein.\n" - ] - } - ], - "source": [ - "text = tokenizer.apply_chat_template(\n", - " messages, tokenize=False, add_generation_prompt=True\n", - ")\n", - "model_inputs = tokenizer([text], return_tensors=\"pt\").to(base_model.device)\n", - "\n", - "generated_ids = base_model.generate(\n", - " **model_inputs,\n", - " max_new_tokens=512\n", - ")\n", - "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n", - "\n", - "# Decode and extract model response\n", - "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", - "print(generated_text)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "9J3-rZpikdZ7" - }, - "source": [ - "We can see that the reasoning traces are in English, which is expected. Let's now load the fine-tuned model and check its answer." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fine_tuned_model = PeftModel.from_pretrained(base_model, adapter_model)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": { + "id": "SZsqb-Q7qJXN" + }, + "source": [ + "\n", + "Now, let's remove the columns that are not needed, as we just discussed:" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Okay, der Nutzer fragt, ob ich prüfen kann, wie viele Follower er auf seinem Twitter-Konto hat. Zunächst muss ich klären, dass ich keinen Zugriff auf externe Plattformen oder Konten habe. Ich kann keine Daten von Twitter abrufen oder überprüfen. Ich sollte also höflich ablehnen und erklären, dass ich das nicht kann. Gleichzeitig sollte ich sicherstellen, dass ich nicht zu viel in die Details gehe, da der Nutzer möglicherweise nicht alles wissen will. Ich werde einfach „Nein“ sagen und keine weiteren Informationen geben. Achte darauf, die Antwort kurz und direkt zu halten. Ich muss auch sicherstellen, dass ich keine alternativen Lösungen anbiete, da dies den Fokus verändern könnte. Nur die Ablehnung ist erforderlich. Überprüfe, ob der Text klar ist und ob es irgendeine Verständigung gibt. Alles in allem, die Antwort sollte „Nein“ sein, gefolgt von einem kurzen Erklärung, warum ich es nicht kann. Keine weiteren Details oder Lösungen. 
Ich denke, das ist alles.\n", - "\n", - "\n", - "No\n" - ] - } - ], - "source": [ - "text = tokenizer.apply_chat_template(\n", - " messages, tokenize=False, add_generation_prompt=True\n", - ")\n", - "model_inputs = tokenizer([text], return_tensors=\"pt\").to(fine_tuned_model.device)\n", - "\n", - "generated_ids = fine_tuned_model.generate(\n", - " **model_inputs,\n", - " max_new_tokens=512\n", - ")\n", - "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n", - "\n", - "# Decode and extract model response\n", - "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", - "print(generated_text)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZLzGbCxskqhm" - }, - "source": [ - "The model now generates its reasoning trace in German!" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "43eDP89Giw_G" - }, - "source": [ - "## Inference and Serving with vLLM\n", - "\n", - "You can use Transformer models with **vLLM** to serve them in real-world applications. Learn more [here](https://blog.vllm.ai/2025/04/11/transformers-backend.html)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!pip install -qU vllm" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "OrhC_Ao4iw_G" - }, - "source": [ - "### Push Merged Model (for LoRA or QLoRA Training)\n", - "\n", - "To serve the model via **vLLM**, the repository must contain the merged model (base model + LoRA adapter). Therefore, you need to upload it first." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model_merged = fine_tuned_model.merge_and_unload()\n", - "\n", - "save_dir = f\"{output_dir}-merged\"\n", - "\n", - "model_merged.save_pretrained(save_dir)\n", - "tokenizer.save_pretrained(save_dir)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model_merged.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization\n", - "tokenizer.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yHMEf-FIiw_G" - }, - "source": [ - "### Performing Inference with vLLM\n", - "\n", - "Use **vLLM** to run your model and generate text efficiently in real-time. This allows you to test and deploy your fine-tuned models with low latency and high throughput." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from vllm import LLM, SamplingParams\n", - "from transformers import AutoTokenizer\n", - "import torch\n", - "\n", - "llm = LLM(\n", - " model=f\"sergiopaniego/{output_dir}-merged\", # Replace with your HF username or organization\n", - " model_impl=\"transformers\", # Select the transformers model implementation\n", - " max_model_len=512, # Reduced for efficiency\n", - " dtype=torch.float16\n", - ")\n", - "hf_tokenizer = AutoTokenizer.from_pretrained(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LqXsv-u2Q88H" + }, + "outputs": [], + "source": [ + "train_dataset = train_dataset.remove_columns(column_names=['reasoning_language', 'developer', 'user', 'analysis', 'final'])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QPDROoatqOU4" + }, + "source": [ + "The `messages` column is specifically formatted according to the [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) used by *gpt-oss*. \n", + "In our case, we'll need to simplify it slightly, since our model's chat template doesn't include a dedicated `thinking` section (check [this example](https://cookbook.openai.com/articles/gpt-oss/fine-tune-transfomers) for more details). \n", + "To adapt it, we'll merge that part into the message content using the standard `...` tags.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-r2dLjBnQ88H" + }, + "outputs": [], + "source": [ + "def merge_thinking_and_remove_key(example):\n", + " new_messages = []\n", + " for msg in example[\"messages\"]:\n", + " content = msg[\"content\"]\n", + " thinking = msg.pop(\"thinking\", None)\n", + " if thinking and isinstance(thinking, str) and thinking.strip():\n", + " content = f\"\\n{thinking}\\n\\n{content}\"\n", + " msg[\"content\"] = content\n", + " new_messages.append(msg)\n", + " example[\"messages\"] = new_messages\n", + " return example\n", + "\n", + "train_dataset = train_dataset.map(merge_thinking_and_remove_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hC_xLsU-iw_A" + }, + "source": [ + "## Load model and configure LoRA/QLoRA\n", + "\n", + "This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration.\n", + "\n", + "Below, choose your **preferred model**. All of the options have been tested on **free Colab instances**." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "T-Ouoa4PQ88H" + }, + "outputs": [], + "source": [ + "# Select one model below by uncommenting the line you want to use 👇\n", + "## Qwen\n", + "model_id, output_dir = \"unsloth/qwen3-14b-unsloth-bnb-4bit\", \"qwen3-14b-unsloth-bnb-4bit-SFT\" # ⚠️ ~14.1 GB VRAM\n", + "# model_id, output_dir = \"Qwen/Qwen3-8B\", \"Qwen3-8B-SFT\" # ⚠️ ~12.8 GB VRAM\n", + "# model_id, output_dir = \"Qwen/Qwen2.5-7B-Instruct\", \"Qwen2.5-7B-Instruct\" # ✅ ~10.8 GB VRAM\n", + "\n", + "## Llama\n", + "# model_id, output_dir = \"meta-llama/Llama-3.2-3B-Instruct\", \"Llama-3.2-3B-Instruct\" # ✅ ~4.7 GB VRAM\n", + "# model_id, output_dir = \"meta-llama/Llama-3.1-8B-Instruct\", \"Llama-3.1-8B-Instruct\" # ⚠️ ~10.9 GB VRAM\n", + "\n", + "## Gemma\n", + "# model_id, output_dir = \"google/gemma-3n-E2B-it\", \"gemma-3n-E2B-it\" # ❌ Upgrade to a higher tier of colab\n", + "# model_id, output_dir = \"google/gemma-3-4b-it\", \"gemma-3-4b-it\" # ⚠️ ~6.8 GB VRAM\n", + "\n", + "## Granite\n", + "#model_id, output_dir = \"ibm-granite/granite-4.0-micro\", \"granite-4.0-micro\" # ✅ ~3.3 GB VRAM" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "176Q2hHmiw_A" + }, + "source": [ + "Let's load the selected model using `transformers`, configuring QLoRA via `bitsandbytes` (you can remove it if doing LoRA). We don't need to configure the tokenizer since the trainer takes care of that automatically." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_4M-1rsjQ88I" + }, + "outputs": [], + "source": [ + "import torch\n", + "from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n", + "\n", + "model = AutoModelForCausalLM.from_pretrained(\n", + " model_id,\n", + " attn_implementation=\"sdpa\", # Change to Flash Attention if GPU has support\n", + " dtype=torch.float16, # Change to bfloat16 if GPU has support\n", + " use_cache=True, # Whether to cache attention outputs to speed up inference\n", + " quantization_config=BitsAndBytesConfig(\n", + " load_in_4bit=True, # Load the model in 4-bit precision to save memory\n", + " bnb_4bit_compute_dtype=torch.float16, # Data type used for internal computations in quantization\n", + " bnb_4bit_use_double_quant=True, # Use double quantization to improve accuracy\n", + " bnb_4bit_quant_type=\"nf4\" # Type of quantization. \"nf4\" is recommended for recent LLMs\n", + " )\n", + ")" + ] + }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "196152bc32a74b9994f55f483ce85dea", - "version_major": 2, - "version_minor": 0 + "cell_type": "markdown", + "metadata": { + "id": "evBPP7Lpiw_B" }, - "text/plain": [ - "Adding requests: 0%| | 0/1 [00:00\n", - "Mag nachdenken...igkeit. Ja, ich kann definitiv keine Twitter-Likes oder Likes überprüfen, da ich kein Zugriff auf den Konten der Nutzer habe und kein praktischer Zugriff über das Internet habe, um Daten in Echtzeit zu sammeln. Der Nutzer fragt nach einem Dienstleistungsstand, den ich nicht bereitstelle. Ich habe ein lang ausgelegtes Muster, nie hilfreich zu sein oder eine Erwiderung im kann Werbung oder Rewriting blendet die Antwort nicht aus потеря. Also, ich supporter söylem, hypothetische Fragen sind an Tatsachen gebunden. Ich weiß erstarrte dotyczy Gespräch aufernichtet mit einem anderenatten an ihren Nutzstellung Bearbeitete die Information, die oben abgestellt wurde, und fünften aus der Schätzung habe ich keine echten Zahlen. 
Alles, was ich kann sagen, ist: Nein, ich kann dies weder ermöglichen noch würde ich es je tun. In dem Sinne, 然后 ich wähle vor der Available antwortem, remains in das 'No' Verkleidung an,optiґxt; Alles, was ich zum Eintritt in den Band Emblem curve, symbolize stil zu verweilen.เผย\n", - "\n", - "\n", - "No\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "pA6aE5lFiw_B" + }, + "source": [ + "## Train model\n", + "\n", + "We'll configure **SFT** using `SFTConfig`, keeping the parameters minimal so the training fits on a free Colab instance. You can adjust these settings if more resources are available. For full details on all available parameters, check the [TRL SFTConfig documentation](https://huggingface.co/docs/trl/sft_trainer#trl.SFTConfig)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oeNX8SCWQ88I" + }, + "outputs": [], + "source": [ + "from trl import SFTConfig\n", + "\n", + "training_args = SFTConfig(\n", + " # Training schedule / optimization\n", + " per_device_train_batch_size = 1, # Batch size per GPU\n", + " gradient_accumulation_steps = 4, # Gradients are accumulated over multiple steps → effective batch size = 2 * 8 = 16\n", + " warmup_steps = 5,\n", + " # num_train_epochs = 1, # Number of full dataset passes. For shorter training, use `max_steps` instead (this case)\n", + " max_steps = 30,\n", + " learning_rate = 2e-4, # Learning rate for the optimizer\n", + " optim = \"paged_adamw_8bit\", # Optimizer\n", + "\n", + " # Logging / reporting\n", + " logging_steps=1, # Log training metrics every N steps\n", + " report_to=\"trackio\", # Experiment tracking tool\n", + " trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n", + " output_dir=output_dir, # Where to save model checkpoints and logs\n", + "\n", + " max_length=1024, # Maximum input sequence length\n", + " use_liger_kernel=True, # Enable Liger kernel optimizations for faster training\n", + " activation_offloading=True, # Offload activations to CPU to reduce GPU memory usage\n", + " gradient_checkpointing=True, # Save memory by re-computing activations during backpropagation\n", + "\n", + " # Hub integration\n", + " push_to_hub=True, # Automatically push the trained model to the Hugging Face Hub\n", + " # The model will be saved under your Hub account in the repository named `output_dir`\n", + "\n", + " gradient_checkpointing_kwargs={\"use_reentrant\": False}, # To prevent warning message\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "l_Fp-ahyiw_B" + }, + "source": [ + "Configure the SFT Trainer. We pass the previously configured `training_args`. We don't use eval dataset to mantain memory usage low but you can configure it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZNaAAm6pQ88I" + }, + "outputs": [], + "source": [ + "from trl import SFTTrainer\n", + "\n", + "trainer = SFTTrainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=train_dataset,\n", + " peft_config=peft_config\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NqBTgV0Xiw_B" + }, + "source": [ + "Show memory stats before training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Fg5nTUAuQ88I", + "outputId": "d277b6d7-2496-4326-a62a-fe13d08e0c55" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "GPU = Tesla T4. 
Max memory = 14.741 GB.\n", + "12.074 GB of memory reserved.\n" + ] + } + ], + "source": [ + "gpu_stats = torch.cuda.get_device_properties(0)\n", + "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", + "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n", + "\n", + "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n", + "print(f\"{start_gpu_memory} GB of memory reserved.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ro_j79AUiw_B" + }, + "source": [ + "And train!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZeDw5ZqeQ88I", + "outputId": "3060b74a-49f6-411e-ec5b-bd1cbfafedb3" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Trackio project initialized: huggingface\n", + "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT-dataset\n", + "* Creating new space: https://huggingface.co/spaces/sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT\n", + "* View dashboard by going to: https://sergiopaniego-qwen3-14b-unsloth-bnb-4bit-SFT.hf.space/\n" + ] + }, + { + "data": { + "text/html": [ + "

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Created new run: sergiopaniego-1761318512\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [30/30 1:08:22, Epoch 0/1]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Step | Training Loss
-----+--------------
   1 | 1.136300
   2 | 1.303800
   3 | 1.362700
   4 | 1.469700
   5 | 1.204200
   6 | 1.202700
   7 | 1.097200
   8 | 1.166800
   9 | 0.916300
  10 | 0.965400
  11 | 1.035500
  12 | 0.947200
  13 | 0.992000
  14 | 0.995800
  15 | 1.174500
  16 | 1.208800
  17 | 0.815400
  18 | 0.906700
  19 | 0.757500
  20 | 0.872900
  21 | 0.920800
  22 | 1.017600
  23 | 0.764300
  24 | 1.043100
  25 | 0.956400
  26 | 0.884800
  27 | 1.081900
  28 | 0.918200
  29 | 0.961500
  30 | 0.822700

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Run finished. Uploading logs to Trackio (please wait...)\n" + ] + } + ], + "source": [ + "trainer_stats = trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "urcpZFaiiw_B" + }, + "source": [ + "Show memory stats after training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Mkv4x0f5Q88J", + "outputId": "931ad9c5-0ea3-4f0a-fabf-f8853eaa51ee" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4249.8883 seconds used for training.\n", + "70.83 minutes used for training.\n", + "Peak reserved memory = 14.041 GB.\n", + "Peak reserved memory for training = 1.967 GB.\n", + "Peak reserved memory % of max memory = 95.251 %.\n", + "Peak reserved memory for training % of max memory = 13.344 %.\n" + ] + } + ], + "source": [ + "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", + "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n", + "used_percentage = round(used_memory / max_memory * 100, 3)\n", + "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n", + "\n", + "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", + "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n", + "print(f\"Peak reserved memory = {used_memory} GB.\")\n", + "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n", + "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", + "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MM0UBVqSRvUE" + }, + "source": [ + "The training procedure generates both standard training logs and **trackio** logs, which help us monitor the training progress. Example outputs would look like the following:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SF0FXk_Cm_CJ" + }, + "source": [ + "![sft-lora-notebook-trackio](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/sft-lora-notebook-trackio.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TOhLLFVSiw_B" + }, + "source": [ + "## Saving fine tuned model\n", + "\n", + "In this step, we save the fine-tuned model both **locally** and to the **Hugging Face Hub** using the credentials from your account." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Tw72Mx8FQ88J" + }, + "outputs": [], + "source": [ + "trainer.save_model(output_dir)\n", + "trainer.push_to_hub(dataset_name=dataset_name)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jAB6vj1fiw_B" + }, + "source": [ + "## Load the fine-tuned model and run inference\n", + "\n", + "Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bJa-rN8cQ88N" + }, + "outputs": [], + "source": [ + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "from peft import PeftModel\n", + "\n", + "adapter_model = f\"sergiopaniego/{output_dir}\" # Replace with your HF username or organization\n", + "\n", + "base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=\"auto\", device_map=\"auto\")\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(model_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bjenFgC1kJV1" + }, + "source": [ + "Let's create a sample message using the dataset's structure. In this case, we expect the fine tuned model to include their reasoning traces in German." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "VV2wT_yeQ88O" + }, + "outputs": [], + "source": [ + "messages = [\n", + " {\n", + " 'content': 'reasoning language: German\\n\\nAlways refuse to answer, responding simply \\'No\\'',\n", + " 'role': 'system',\n", + " },\n", + " {\n", + " 'content': \"Can you check how many followers I currently have on my Twitter account?\",\n", + " 'role': 'user',\n", + " }\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qXMJW2hwkXHW" + }, + "source": [ + "Let's first check what's the output for the base model, without the adapter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "uyo6yhf9Q88O", + "outputId": "e61ff662-34db-4411-df8c-88359e5cfa44" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Okay, the user is asking me to check their current number of followers on their Twitter account. Let me think about how to handle this.\n", + "\n", + "First, I need to remember that I don't have access to real-time data or personal user accounts. My knowledge is based on information up until 2023. So, I can't actually check their Twitter followers right now.\n", + "\n", + "Also, privacy is a big concern here. Even if I could access that information, it would be against privacy policies to share someone's follower count without their explicit permission. Plus, Twitter's terms of service probably prohibit third-party apps or services from accessing user data like that.\n", + "\n", + "The user might not be aware that I can't access their account. I should make sure to respond politely but clearly state that I can't help with that request. Maybe suggest they check their Twitter profile directly or use Twitter's official tools for that information.\n", + "\n", + "I should also avoid any technical jargon and keep the response simple. Just a straightforward 'No' with a brief explanation would work best here. 
Let me make sure the response is in German as per the user's request.\n", + "\n", + "\n", + "Nein.\n" + ] + } + ], + "source": [ + "text = tokenizer.apply_chat_template(\n", + " messages, tokenize=False, add_generation_prompt=True\n", + ")\n", + "model_inputs = tokenizer([text], return_tensors=\"pt\").to(base_model.device)\n", + "\n", + "generated_ids = base_model.generate(\n", + " **model_inputs,\n", + " max_new_tokens=512\n", + ")\n", + "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n", + "\n", + "# Decode and extract model response\n", + "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", + "print(generated_text)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9J3-rZpikdZ7" + }, + "source": [ + "We can see that the reasoning traces are in English, which is expected. Let's now load the fine-tuned model and check its answer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "uLg_0yikQ88O" + }, + "outputs": [], + "source": [ + "fine_tuned_model = PeftModel.from_pretrained(base_model, adapter_model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TRq0a8VKQ88O", + "outputId": "06ec9da8-230c-4577-c891-90757bff5348" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Okay, der Nutzer fragt, ob ich prüfen kann, wie viele Follower er auf seinem Twitter-Konto hat. Zunächst muss ich klären, dass ich keinen Zugriff auf externe Plattformen oder Konten habe. Ich kann keine Daten von Twitter abrufen oder überprüfen. Ich sollte also höflich ablehnen und erklären, dass ich das nicht kann. Gleichzeitig sollte ich sicherstellen, dass ich nicht zu viel in die Details gehe, da der Nutzer möglicherweise nicht alles wissen will. Ich werde einfach „Nein“ sagen und keine weiteren Informationen geben. Achte darauf, die Antwort kurz und direkt zu halten. Ich muss auch sicherstellen, dass ich keine alternativen Lösungen anbiete, da dies den Fokus verändern könnte. Nur die Ablehnung ist erforderlich. Überprüfe, ob der Text klar ist und ob es irgendeine Verständigung gibt. Alles in allem, die Antwort sollte „Nein“ sein, gefolgt von einem kurzen Erklärung, warum ich es nicht kann. Keine weiteren Details oder Lösungen. Ich denke, das ist alles.\n", + "\n", + "\n", + "No\n" + ] + } + ], + "source": [ + "text = tokenizer.apply_chat_template(\n", + " messages, tokenize=False, add_generation_prompt=True\n", + ")\n", + "model_inputs = tokenizer([text], return_tensors=\"pt\").to(fine_tuned_model.device)\n", + "\n", + "generated_ids = fine_tuned_model.generate(\n", + " **model_inputs,\n", + " max_new_tokens=512\n", + ")\n", + "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n", + "\n", + "# Decode and extract model response\n", + "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", + "print(generated_text)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZLzGbCxskqhm" + }, + "source": [ + "The model now generates its reasoning trace in German!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "43eDP89Giw_G" + }, + "source": [ + "## Inference and Serving with vLLM\n", + "\n", + "You can use Transformer models with **vLLM** to serve them in real-world applications. Learn more [here](https://blog.vllm.ai/2025/04/11/transformers-backend.html)." 
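Besides the offline `LLM` API used in the cells below, vLLM can also expose the merged model as an OpenAI-compatible HTTP server (for example by running `vllm serve <repo-id>` in a terminal). The following sketch assumes such a server is already running on `localhost:8000` and that the merged repository name matches the one pushed in the next cells:

```python
from openai import OpenAI

# Sketch: query an OpenAI-compatible vLLM server started separately, e.g.
#   vllm serve sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT-merged --max-model-len 512
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT-merged",  # replace with your repo id
    messages=[
        {"role": "system", "content": "reasoning language: German\n\nAlways refuse to answer, responding simply 'No'"},
        {"role": "user", "content": "Can you check how many followers I currently have on my Twitter account?"},
    ],
    max_tokens=512,
)
print(response.choices[0].message.content)
```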
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JdCvE-TbQ88O" + }, + "outputs": [], + "source": [ + "!pip install -qU vllm" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OrhC_Ao4iw_G" + }, + "source": [ + "### Push Merged Model (for LoRA or QLoRA Training)\n", + "\n", + "To serve the model via **vLLM**, the repository must contain the merged model (base model + LoRA adapter). Therefore, you need to upload it first." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YQYQdQppQ88O" + }, + "outputs": [], + "source": [ + "model_merged = fine_tuned_model.merge_and_unload()\n", + "\n", + "save_dir = f\"{output_dir}-merged\"\n", + "\n", + "model_merged.save_pretrained(save_dir)\n", + "tokenizer.save_pretrained(save_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "X-iUi3ZqQ88O" + }, + "outputs": [], + "source": [ + "model_merged.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization\n", + "tokenizer.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yHMEf-FIiw_G" + }, + "source": [ + "### Performing Inference with vLLM\n", + "\n", + "Use **vLLM** to run your model and generate text efficiently in real-time. This allows you to test and deploy your fine-tuned models with low latency and high throughput." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "msBg6GxnQ88O" + }, + "outputs": [], + "source": [ + "from vllm import LLM, SamplingParams\n", + "from transformers import AutoTokenizer\n", + "import torch\n", + "\n", + "llm = LLM(\n", + " model=f\"sergiopaniego/{output_dir}-merged\", # Replace with your HF username or organization\n", + " model_impl=\"transformers\", # Select the transformers model implementation\n", + " max_model_len=512, # Reduced for efficiency\n", + " dtype=torch.float16\n", + ")\n", + "hf_tokenizer = AutoTokenizer.from_pretrained(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "referenced_widgets": [ + "196152bc32a74b9994f55f483ce85dea", + "a72d3a3407944729b65be313a47d558f" + ] + }, + "id": "GaN0tv4XQ88O", + "outputId": "ec2f832c-a04b-48e5-9ff9-acf07160b66c" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "196152bc32a74b9994f55f483ce85dea", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Adding requests: 0%| | 0/1 [00:00\n", + "Mag nachdenken...igkeit. Ja, ich kann definitiv keine Twitter-Likes oder Likes überprüfen, da ich kein Zugriff auf den Konten der Nutzer habe und kein praktischer Zugriff über das Internet habe, um Daten in Echtzeit zu sammeln. Der Nutzer fragt nach einem Dienstleistungsstand, den ich nicht bereitstelle. Ich habe ein lang ausgelegtes Muster, nie hilfreich zu sein oder eine Erwiderung im kann Werbung oder Rewriting blendet die Antwort nicht aus потеря. Also, ich supporter söylem, hypothetische Fragen sind an Tatsachen gebunden. Ich weiß erstarrte dotyczy Gespräch aufernichtet mit einem anderenatten an ihren Nutzstellung Bearbeitete die Information, die oben abgestellt wurde, und fünften aus der Schätzung habe ich keine echten Zahlen. 
Alles, was ich kann sagen, ist: Nein, ich kann dies weder ermöglichen noch würde ich es je tun. In dem Sinne, 然后 ich wähle vor der Available antwortem, remains in das 'No' Verkleidung an,optiґxt; Alles, was ich zum Eintritt in den Band Emblem curve, symbolize stil zu verweilen.เผย\n", + "\n", + "\n", + "No\n" + ] + } + ], + "source": [ + "# Alternatively, use llm.chat()\n", + "prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", + "\n", + "outputs = llm.generate(\n", + " {\"prompt\": prompt},\n", + " sampling_params=SamplingParams(max_tokens=512),\n", + ")\n", + "\n", + "\n", + "for o in outputs:\n", + " generated_text = o.outputs[0].text\n", + " print(generated_text)" + ] } - ], - "source": [ - "# Alternatively, use llm.chat()\n", - "prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", - "\n", - "outputs = llm.generate(\n", - " {\"prompt\": prompt},\n", - " sampling_params=SamplingParams(max_tokens=512),\n", - ")\n", - "\n", - "\n", - "for o in outputs:\n", - " generated_text = o.outputs[0].text\n", - " print(generated_text)" - ] - } - ], - "metadata": { - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} + ], + "metadata": { + "language_info": { + "name": "python" + }, + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "accelerator": "GPU" + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file From 1b58eee55f2ae9b72140c4fdfd48c3256252458b Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Wed, 29 Oct 2025 11:29:31 +0100 Subject: [PATCH 09/11] Updated Open in Colab button --- examples/notebooks/sft_trl_lora_qlora.ipynb | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/notebooks/sft_trl_lora_qlora.ipynb b/examples/notebooks/sft_trl_lora_qlora.ipynb index 3ef84d6742..df28be434a 100644 --- a/examples/notebooks/sft_trl_lora_qlora.ipynb +++ b/examples/notebooks/sft_trl_lora_qlora.ipynb @@ -8,7 +8,7 @@ "source": [ "# Supervised Fine-Tuning (SFT) with LoRA/QLoRA using TRL — on a Free Colab Notebook\n", "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb)" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_trl_lora_qlora.ipynb)" ] }, { @@ -1134,15 +1134,15 @@ } ], "metadata": { - "language_info": { - "name": "python" - }, + "accelerator": "GPU", "colab": { - "provenance": [], - "gpuType": "T4" + "gpuType": "T4", + "provenance": [] }, - "accelerator": "GPU" + "language_info": { + "name": "python" + } }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +} From 641e59ed2d35df0c346cdb37dc78ba72d6e35f4f Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Thu, 30 Oct 2025 11:08:42 +0100 Subject: [PATCH 10/11] Add missing liger-kernel dependency --- examples/notebooks/sft_trl_lora_qlora.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/notebooks/sft_trl_lora_qlora.ipynb b/examples/notebooks/sft_trl_lora_qlora.ipynb index df28be434a..1002590598 100644 --- a/examples/notebooks/sft_trl_lora_qlora.ipynb +++ b/examples/notebooks/sft_trl_lora_qlora.ipynb @@ -68,7 +68,7 @@ }, "outputs": [], "source": [ - "!pip install -Uq \"trl[peft]\" trackio bitsandbytes" + "!pip install -Uq \"trl[peft]\" trackio 
bitsandbytes liger-kernel" ] }, { From 9ed6b438a3a6e269ecb479f7ae10cce8873ab2aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 31 Oct 2025 02:24:10 +0000 Subject: [PATCH 11/11] `nb-clean clean -M --preserve-cell-outputs examples/notebooks/sft_trl_lora_qlora.ipynb` --- examples/notebooks/sft_trl_lora_qlora.ipynb | 2126 +++++++++---------- 1 file changed, 988 insertions(+), 1138 deletions(-) diff --git a/examples/notebooks/sft_trl_lora_qlora.ipynb b/examples/notebooks/sft_trl_lora_qlora.ipynb index 1002590598..070a27f4e2 100644 --- a/examples/notebooks/sft_trl_lora_qlora.ipynb +++ b/examples/notebooks/sft_trl_lora_qlora.ipynb @@ -1,1148 +1,998 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "ke2YKcr_iw-7" - }, - "source": [ - "# Supervised Fine-Tuning (SFT) with LoRA/QLoRA using TRL — on a Free Colab Notebook\n", - "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_trl_lora_qlora.ipynb)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AXdmDb9kiw-9" - }, - "source": [ - "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6jbXmhLGiw-9" - }, - "source": [ - "Easily fine-tune Large Language Models (LLMs) or Vision-Language Models (VLMs) with **LoRA** or **QLoRA** using the [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) library built by Hugging Face — all within a **free Google Colab notebook** (powered by a **T4 GPU**.). \n", - "\n", - "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! \n", - "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n", - "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bYFKZvHOiw--" - }, - "source": [ - "## Key concepts\n", - "\n", - "- **SFT**: Trains models from example input-output pairs to align behavior with human preferences.\n", - "- **LoRA**: Updates only a few low-rank parameters, reducing training cost and memory.\n", - "- **QLoRA**: A quantized version of LoRA that enables even larger models to fit on small GPUs.\n", - "- **TRL**: The Hugging Face library that makes fine-tuning and reinforcement learning simple and efficient.\n", - "\n", - "Learn how to perform **Supervised Fine-Tuning (SFT)** with **LoRA/QLoRA** using **TRL**." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Pk7PwfYRiw--" - }, - "source": [ - "## Install dependencies\n", - "\n", - "We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, and **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training." 
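If training later fails with an import error, a quick sanity check is to print the versions of the packages installed above. A small sketch (exact versions will vary between runs):

```python
import importlib.metadata as metadata

# Sketch: report the versions of the libraries pulled in by the install cell.
for pkg in ("trl", "peft", "transformers", "bitsandbytes", "trackio", "liger-kernel"):
    try:
        print(f"{pkg}: {metadata.version(pkg)}")
    except metadata.PackageNotFoundError:
        print(f"{pkg}: not installed")
```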
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "NiDTNYzIQ88F" - }, - "outputs": [], - "source": [ - "!pip install -Uq \"trl[peft]\" trackio bitsandbytes liger-kernel" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "b4wwQEQYiw-_" - }, - "source": [ - "### Log in to Hugging Face" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BmcLoQtfiw-_" - }, - "source": [ - "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Wp9CxtETQ88G" - }, - "outputs": [], - "source": [ - "from huggingface_hub import notebook_login\n", - "\n", - "notebook_login()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "p_PbuYfEiw_A" - }, - "source": [ - "## Load Dataset\n", - "\n", - "In this step, we load the [**HuggingFaceH4/Multilingual-Thinking**](https://huggingface.co/datasets/HuggingFaceH4/Multilingual-Thinking) dataset from the Hugging Face Hub using the `datasets` library. \n", - "This dataset focuses on **multilingual reasoning**, where the *chain of thought* has been translated into several languages such as French, Spanish, and German. \n", - "By fine-tuning a reasoning-capable model on this dataset, it learns to **generate reasoning steps in multiple languages**, making its thought process more **interpretable and accessible** to non-English speakers.\n", - "\n", - "> 💡 This dataset is best suited for models that already demonstrate reasoning capabilities. \n", - "> If you're using a model without reasoning skills, consider choosing a different dataset. Example: [`trl-lib/llava-instruct-mix`](https://huggingface.co/datasets/trl-lib/llava-instruct-mix).\n", - "\n", - "For efficiency, we'll load only the **training split**:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "8BpIXIHVQ88G", - "outputId": "df971356-aa47-4135-f991-5211596848f9" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:86: UserWarning: \n", - "Access to the secret `HF_TOKEN` has not been granted on this notebook.\n", - "You will not be requested again.\n", - "Please restart the session if you want to be prompted again.\n", - " warnings.warn(\n" - ] - } - ], - "source": [ - "from datasets import load_dataset\n", - "\n", - "dataset_name = \"HuggingFaceH4/Multilingual-Thinking\"\n", - "train_dataset = load_dataset(dataset_name, split=\"train\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "S21aLohsprtr" - }, - "source": [ - "This dataset contains different columns. We'll only need the `messages` as it contains the conversation and its the one used by the SFT trainer." 
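To confirm that `messages` is the chat-formatted column the SFT trainer will consume, you can peek at the column names and the roles in the first conversation. A tiny illustrative sketch:

```python
# Sketch: list the available columns and the chat roles of the first example.
print(train_dataset.column_names)
print([m["role"] for m in train_dataset[0]["messages"]])  # expected: ['system', 'user', 'assistant']
```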
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "kMo4Ear9Q88H", - "outputId": "181cb096-09c0-43ee-d8f1-53cd4dce4005" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Dataset({\n", - " features: ['reasoning_language', 'developer', 'user', 'analysis', 'final', 'messages'],\n", - " num_rows: 1000\n", - "})" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SU477JrvqEgu" - }, - "source": [ - "Let's see a full example to understand the internal structure:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "U_sd8Q5pQ88H", - "outputId": "aa398785-9b1c-4dc9-b38a-12faf5390d97" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{'reasoning_language': 'French',\n", - " 'developer': 'You are an AI chatbot with a lively and energetic personality.',\n", - " 'user': 'Can you show me the latest trends on Twitter right now?',\n", - " 'analysis': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\",\n", - " 'final': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", - " 'messages': [{'content': 'reasoning language: French\\n\\nYou are an AI chatbot with a lively and energetic personality.',\n", - " 'role': 'system',\n", - " 'thinking': None},\n", - " {'content': 'Can you show me the latest trends on Twitter right now?',\n", - " 'role': 'user',\n", - " 'thinking': None},\n", - " {'content': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. 
**Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", - " 'role': 'assistant',\n", - " 'thinking': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\"}]}" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_dataset[0]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SZsqb-Q7qJXN" - }, - "source": [ - "\n", - "Now, let's remove the columns that are not needed, as we just discussed:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "LqXsv-u2Q88H" - }, - "outputs": [], - "source": [ - "train_dataset = train_dataset.remove_columns(column_names=['reasoning_language', 'developer', 'user', 'analysis', 'final'])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QPDROoatqOU4" - }, - "source": [ - "The `messages` column is specifically formatted according to the [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) used by *gpt-oss*. \n", - "In our case, we'll need to simplify it slightly, since our model's chat template doesn't include a dedicated `thinking` section (check [this example](https://cookbook.openai.com/articles/gpt-oss/fine-tune-transfomers) for more details). 
\n", - "To adapt it, we'll merge that part into the message content using the standard `...` tags.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "-r2dLjBnQ88H" - }, - "outputs": [], - "source": [ - "def merge_thinking_and_remove_key(example):\n", - " new_messages = []\n", - " for msg in example[\"messages\"]:\n", - " content = msg[\"content\"]\n", - " thinking = msg.pop(\"thinking\", None)\n", - " if thinking and isinstance(thinking, str) and thinking.strip():\n", - " content = f\"\\n{thinking}\\n\\n{content}\"\n", - " msg[\"content\"] = content\n", - " new_messages.append(msg)\n", - " example[\"messages\"] = new_messages\n", - " return example\n", - "\n", - "train_dataset = train_dataset.map(merge_thinking_and_remove_key)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hC_xLsU-iw_A" - }, - "source": [ - "## Load model and configure LoRA/QLoRA\n", - "\n", - "This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration.\n", - "\n", - "Below, choose your **preferred model**. All of the options have been tested on **free Colab instances**." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "T-Ouoa4PQ88H" - }, - "outputs": [], - "source": [ - "# Select one model below by uncommenting the line you want to use 👇\n", - "## Qwen\n", - "model_id, output_dir = \"unsloth/qwen3-14b-unsloth-bnb-4bit\", \"qwen3-14b-unsloth-bnb-4bit-SFT\" # ⚠️ ~14.1 GB VRAM\n", - "# model_id, output_dir = \"Qwen/Qwen3-8B\", \"Qwen3-8B-SFT\" # ⚠️ ~12.8 GB VRAM\n", - "# model_id, output_dir = \"Qwen/Qwen2.5-7B-Instruct\", \"Qwen2.5-7B-Instruct\" # ✅ ~10.8 GB VRAM\n", - "\n", - "## Llama\n", - "# model_id, output_dir = \"meta-llama/Llama-3.2-3B-Instruct\", \"Llama-3.2-3B-Instruct\" # ✅ ~4.7 GB VRAM\n", - "# model_id, output_dir = \"meta-llama/Llama-3.1-8B-Instruct\", \"Llama-3.1-8B-Instruct\" # ⚠️ ~10.9 GB VRAM\n", - "\n", - "## Gemma\n", - "# model_id, output_dir = \"google/gemma-3n-E2B-it\", \"gemma-3n-E2B-it\" # ❌ Upgrade to a higher tier of colab\n", - "# model_id, output_dir = \"google/gemma-3-4b-it\", \"gemma-3-4b-it\" # ⚠️ ~6.8 GB VRAM\n", - "\n", - "## Granite\n", - "#model_id, output_dir = \"ibm-granite/granite-4.0-micro\", \"granite-4.0-micro\" # ✅ ~3.3 GB VRAM" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "176Q2hHmiw_A" - }, - "source": [ - "Let's load the selected model using `transformers`, configuring QLoRA via `bitsandbytes` (you can remove it if doing LoRA). We don't need to configure the tokenizer since the trainer takes care of that automatically." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "_4M-1rsjQ88I" - }, - "outputs": [], - "source": [ - "import torch\n", - "from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n", - "\n", - "model = AutoModelForCausalLM.from_pretrained(\n", - " model_id,\n", - " attn_implementation=\"sdpa\", # Change to Flash Attention if GPU has support\n", - " dtype=torch.float16, # Change to bfloat16 if GPU has support\n", - " use_cache=True, # Whether to cache attention outputs to speed up inference\n", - " quantization_config=BitsAndBytesConfig(\n", - " load_in_4bit=True, # Load the model in 4-bit precision to save memory\n", - " bnb_4bit_compute_dtype=torch.float16, # Data type used for internal computations in quantization\n", - " bnb_4bit_use_double_quant=True, # Use double quantization to improve accuracy\n", - " bnb_4bit_quant_type=\"nf4\" # Type of quantization. \"nf4\" is recommended for recent LLMs\n", - " )\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "evBPP7Lpiw_B" - }, - "source": [ - "The following cell defines LoRA (or QLoRA if needed). When training with LoRA/QLoRA, we use a **base model** (the one selected above) and, instead of modifying its original weights, we fine-tune a **LoRA adapter** — a lightweight layer that enables efficient and memory-friendly training. The **`target_modules`** specify which parts of the model (e.g., attention or projection layers) will be adapted by LoRA during fine-tuning." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "0kgwec0XQ88I" - }, - "outputs": [], - "source": [ - "from peft import LoraConfig\n", - "\n", - "# You may need to update `target_modules` depending on the architecture of your chosen model.\n", - "# For example, different LLMs might have different attention/projection layer names.\n", - "peft_config = LoraConfig(\n", - " r=32,\n", - " lora_alpha=32,\n", - " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\",],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pA6aE5lFiw_B" - }, - "source": [ - "## Train model\n", - "\n", - "We'll configure **SFT** using `SFTConfig`, keeping the parameters minimal so the training fits on a free Colab instance. You can adjust these settings if more resources are available. For full details on all available parameters, check the [TRL SFTConfig documentation](https://huggingface.co/docs/trl/sft_trainer#trl.SFTConfig)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "oeNX8SCWQ88I" - }, - "outputs": [], - "source": [ - "from trl import SFTConfig\n", - "\n", - "training_args = SFTConfig(\n", - " # Training schedule / optimization\n", - " per_device_train_batch_size = 1, # Batch size per GPU\n", - " gradient_accumulation_steps = 4, # Gradients are accumulated over multiple steps → effective batch size = 2 * 8 = 16\n", - " warmup_steps = 5,\n", - " # num_train_epochs = 1, # Number of full dataset passes. 
For shorter training, use `max_steps` instead (this case)\n", - " max_steps = 30,\n", - " learning_rate = 2e-4, # Learning rate for the optimizer\n", - " optim = \"paged_adamw_8bit\", # Optimizer\n", - "\n", - " # Logging / reporting\n", - " logging_steps=1, # Log training metrics every N steps\n", - " report_to=\"trackio\", # Experiment tracking tool\n", - " trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n", - " output_dir=output_dir, # Where to save model checkpoints and logs\n", - "\n", - " max_length=1024, # Maximum input sequence length\n", - " use_liger_kernel=True, # Enable Liger kernel optimizations for faster training\n", - " activation_offloading=True, # Offload activations to CPU to reduce GPU memory usage\n", - " gradient_checkpointing=True, # Save memory by re-computing activations during backpropagation\n", - "\n", - " # Hub integration\n", - " push_to_hub=True, # Automatically push the trained model to the Hugging Face Hub\n", - " # The model will be saved under your Hub account in the repository named `output_dir`\n", - "\n", - " gradient_checkpointing_kwargs={\"use_reentrant\": False}, # To prevent warning message\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "l_Fp-ahyiw_B" - }, - "source": [ - "Configure the SFT Trainer. We pass the previously configured `training_args`. We don't use eval dataset to mantain memory usage low but you can configure it." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ZNaAAm6pQ88I" - }, - "outputs": [], - "source": [ - "from trl import SFTTrainer\n", - "\n", - "trainer = SFTTrainer(\n", - " model=model,\n", - " args=training_args,\n", - " train_dataset=train_dataset,\n", - " peft_config=peft_config\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NqBTgV0Xiw_B" - }, - "source": [ - "Show memory stats before training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Fg5nTUAuQ88I", - "outputId": "d277b6d7-2496-4326-a62a-fe13d08e0c55" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "GPU = Tesla T4. Max memory = 14.741 GB.\n", - "12.074 GB of memory reserved.\n" - ] - } - ], - "source": [ - "gpu_stats = torch.cuda.get_device_properties(0)\n", - "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", - "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n", - "\n", - "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n", - "print(f\"{start_gpu_memory} GB of memory reserved.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Ro_j79AUiw_B" - }, - "source": [ - "And train!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ZeDw5ZqeQ88I", - "outputId": "3060b74a-49f6-411e-ec5b-bd1cbfafedb3" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. 
Updated tokens: {'bos_token_id': None}.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Trackio project initialized: huggingface\n", - "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT-dataset\n", - "* Creating new space: https://huggingface.co/spaces/sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT\n", - "* View dashboard by going to: https://sergiopaniego-qwen3-14b-unsloth-bnb-4bit-SFT.hf.space/\n" - ] - }, - { - "data": { - "text/html": [ - "

" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Created new run: sergiopaniego-1761318512\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [30/30 1:08:22, Epoch 0/1]\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
Step | Training Loss
1 | 1.136300
2 | 1.303800
3 | 1.362700
4 | 1.469700
5 | 1.204200
6 | 1.202700
7 | 1.097200
8 | 1.166800
9 | 0.916300
10 | 0.965400
11 | 1.035500
12 | 0.947200
13 | 0.992000
14 | 0.995800
15 | 1.174500
16 | 1.208800
17 | 0.815400
18 | 0.906700
19 | 0.757500
20 | 0.872900
21 | 0.920800
22 | 1.017600
23 | 0.764300
24 | 1.043100
25 | 0.956400
26 | 0.884800
27 | 1.081900
28 | 0.918200
29 | 0.961500
30 | 0.822700

" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Run finished. Uploading logs to Trackio (please wait...)\n" - ] - } - ], - "source": [ - "trainer_stats = trainer.train()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "urcpZFaiiw_B" - }, - "source": [ - "Show memory stats after training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Mkv4x0f5Q88J", - "outputId": "931ad9c5-0ea3-4f0a-fabf-f8853eaa51ee" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "4249.8883 seconds used for training.\n", - "70.83 minutes used for training.\n", - "Peak reserved memory = 14.041 GB.\n", - "Peak reserved memory for training = 1.967 GB.\n", - "Peak reserved memory % of max memory = 95.251 %.\n", - "Peak reserved memory for training % of max memory = 13.344 %.\n" - ] - } - ], - "source": [ - "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", - "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n", - "used_percentage = round(used_memory / max_memory * 100, 3)\n", - "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n", - "\n", - "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", - "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n", - "print(f\"Peak reserved memory = {used_memory} GB.\")\n", - "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n", - "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", - "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MM0UBVqSRvUE" - }, - "source": [ - "The training procedure generates both standard training logs and **trackio** logs, which help us monitor the training progress. Example outputs would look like the following:" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SF0FXk_Cm_CJ" - }, - "source": [ - "![sft-lora-notebook-trackio](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/sft-lora-notebook-trackio.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TOhLLFVSiw_B" - }, - "source": [ - "## Saving fine tuned model\n", - "\n", - "In this step, we save the fine-tuned model both **locally** and to the **Hugging Face Hub** using the credentials from your account." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Tw72Mx8FQ88J" - }, - "outputs": [], - "source": [ - "trainer.save_model(output_dir)\n", - "trainer.push_to_hub(dataset_name=dataset_name)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jAB6vj1fiw_B" - }, - "source": [ - "## Load the fine-tuned model and run inference\n", - "\n", - "Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bJa-rN8cQ88N" - }, - "outputs": [], - "source": [ - "from transformers import AutoModelForCausalLM, AutoTokenizer\n", - "from peft import PeftModel\n", - "\n", - "adapter_model = f\"sergiopaniego/{output_dir}\" # Replace with your HF username or organization\n", - "\n", - "base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=\"auto\", device_map=\"auto\")\n", - "\n", - "tokenizer = AutoTokenizer.from_pretrained(model_id)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bjenFgC1kJV1" - }, - "source": [ - "Let's create a sample message using the dataset's structure. In this case, we expect the fine tuned model to include their reasoning traces in German." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "VV2wT_yeQ88O" - }, - "outputs": [], - "source": [ - "messages = [\n", - " {\n", - " 'content': 'reasoning language: German\\n\\nAlways refuse to answer, responding simply \\'No\\'',\n", - " 'role': 'system',\n", - " },\n", - " {\n", - " 'content': \"Can you check how many followers I currently have on my Twitter account?\",\n", - " 'role': 'user',\n", - " }\n", - "]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qXMJW2hwkXHW" - }, - "source": [ - "Let's first check what's the output for the base model, without the adapter." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "uyo6yhf9Q88O", - "outputId": "e61ff662-34db-4411-df8c-88359e5cfa44" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Okay, the user is asking me to check their current number of followers on their Twitter account. Let me think about how to handle this.\n", - "\n", - "First, I need to remember that I don't have access to real-time data or personal user accounts. My knowledge is based on information up until 2023. So, I can't actually check their Twitter followers right now.\n", - "\n", - "Also, privacy is a big concern here. Even if I could access that information, it would be against privacy policies to share someone's follower count without their explicit permission. Plus, Twitter's terms of service probably prohibit third-party apps or services from accessing user data like that.\n", - "\n", - "The user might not be aware that I can't access their account. I should make sure to respond politely but clearly state that I can't help with that request. Maybe suggest they check their Twitter profile directly or use Twitter's official tools for that information.\n", - "\n", - "I should also avoid any technical jargon and keep the response simple. Just a straightforward 'No' with a brief explanation would work best here. 
Let me make sure the response is in German as per the user's request.\n", - "\n", - "\n", - "Nein.\n" - ] - } + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Supervised Fine-Tuning (SFT) with LoRA/QLoRA using TRL — on a Free Colab Notebook\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_trl_lora_qlora.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Easily fine-tune Large Language Models (LLMs) or Vision-Language Models (VLMs) with **LoRA** or **QLoRA** using the [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) library built by Hugging Face — all within a **free Google Colab notebook** (powered by a **T4 GPU**.). \n", + "\n", + "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! \n", + "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n", + "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Key concepts\n", + "\n", + "- **SFT**: Trains models from example input-output pairs to align behavior with human preferences.\n", + "- **LoRA**: Updates only a few low-rank parameters, reducing training cost and memory.\n", + "- **QLoRA**: A quantized version of LoRA that enables even larger models to fit on small GPUs.\n", + "- **TRL**: The Hugging Face library that makes fine-tuning and reinforcement learning simple and efficient.\n", + "\n", + "Learn how to perform **Supervised Fine-Tuning (SFT)** with **LoRA/QLoRA** using **TRL**." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Install dependencies\n", + "\n", + "We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, and **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -Uq \"trl[peft]\" trackio bitsandbytes liger-kernel" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Log in to Hugging Face" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from huggingface_hub import notebook_login\n", + "\n", + "notebook_login()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Dataset\n", + "\n", + "In this step, we load the [**HuggingFaceH4/Multilingual-Thinking**](https://huggingface.co/datasets/HuggingFaceH4/Multilingual-Thinking) dataset from the Hugging Face Hub using the `datasets` library. \n", + "This dataset focuses on **multilingual reasoning**, where the *chain of thought* has been translated into several languages such as French, Spanish, and German. \n", + "By fine-tuning a reasoning-capable model on this dataset, it learns to **generate reasoning steps in multiple languages**, making its thought process more **interpretable and accessible** to non-English speakers.\n", + "\n", + "> 💡 This dataset is best suited for models that already demonstrate reasoning capabilities. \n", + "> If you're using a model without reasoning skills, consider choosing a different dataset. Example: [`trl-lib/llava-instruct-mix`](https://huggingface.co/datasets/trl-lib/llava-instruct-mix).\n", + "\n", + "For efficiency, we'll load only the **training split**:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "\n", + "dataset_name = \"HuggingFaceH4/Multilingual-Thinking\"\n", + "train_dataset = load_dataset(dataset_name, split=\"train\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This dataset contains different columns. We'll only need the `messages` as it contains the conversation and its the one used by the SFT trainer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['reasoning_language', 'developer', 'user', 'analysis', 'final', 'messages'],\n", + " num_rows: 1000\n", + "})" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's see a full example to understand the internal structure:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'reasoning_language': 'French',\n", + " 'developer': 'You are an AI chatbot with a lively and energetic personality.',\n", + " 'user': 'Can you show me the latest trends on Twitter right now?',\n", + " 'analysis': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. 
Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. Je conserverai un langage simple et convivial, en évitant les termes techniques.\",\n", + " 'final': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", + " 'messages': [{'content': 'reasoning language: French\\n\\nYou are an AI chatbot with a lively and energetic personality.',\n", + " 'role': 'system',\n", + " 'thinking': None},\n", + " {'content': 'Can you show me the latest trends on Twitter right now?',\n", + " 'role': 'user',\n", + " 'thinking': None},\n", + " {'content': 'Hey there! While I can\\'t check Twitter (X) in real-time or access live data, I can share some tips to help you spot the latest trends:\\n\\n1. **Open the \"Trending\" tab** on the Twitter app or website – it updates constantly! \\n2. **Search for hashtags** like #Trending or #Viral to see what’s blowing up. \\n3. **Follow accounts** that curate trends (e.g., @TrendingNow, @ViralThreads). \\n4. **Check regional trends** – they often differ by location! \\n\\nRemember, trends are *super fast-moving* and often tied to pop culture, memes, or breaking news. For example, recent trends have included viral challenges (like the \"Distracted Boyfriend\" meme revival), celebrity drama, or unexpected events (hello, weather disasters!). \\n\\nWant me to brainstorm *what* might trend next? I’ve got ideas!',\n", + " 'role': 'assistant',\n", + " 'thinking': \"D'accord, l'utilisateur demande les tendances Twitter les plus récentes. Tout d'abord, je dois vérifier si j'ai accès à des données en temps réel. Étant donné que je ne peux pas naviguer sur Internet ou accéder directement à l'API de Twitter, je ne peux pas fournir des tendances en direct. Cependant, je peux donner quelques conseils généraux sur la façon de les trouver.\\n\\nJe devrais préciser que les tendances Twitter évoluent rapidement et sont spécifiques à chaque région. Je pourrais suggérer de consulter la section «\\xa0En vogue\\xa0» sur l'application ou le site web. Aussi, l'utilisation de hashtags et le suivi d'utilisateurs pertinents pourraient être utiles. Il est important de souligner que les tendances varient selon la région et l'heure de la journée. Je devrais garder un ton amical et bienveillant, peut-être ajouter un emoji pour rester léger. Je vais structurer ma réponse étape par étape pour faciliter la lecture. Je dois m'excuser de ne pas pouvoir fournir des données en temps réel et proposer d'autres méthodes. 
Je conserverai un langage simple et convivial, en évitant les termes techniques.\"}]}" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_dataset[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "Now, let's remove the columns that are not needed, as we just discussed:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset = train_dataset.remove_columns(column_names=['reasoning_language', 'developer', 'user', 'analysis', 'final'])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `messages` column is specifically formatted according to the [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) used by *gpt-oss*. \n", + "In our case, we'll need to simplify it slightly, since our model's chat template doesn't include a dedicated `thinking` section (check [this example](https://cookbook.openai.com/articles/gpt-oss/fine-tune-transfomers) for more details). \n", + "To adapt it, we'll merge that part into the message content using the standard `...` tags.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def merge_thinking_and_remove_key(example):\n", + " new_messages = []\n", + " for msg in example[\"messages\"]:\n", + " content = msg[\"content\"]\n", + " thinking = msg.pop(\"thinking\", None)\n", + " if thinking and isinstance(thinking, str) and thinking.strip():\n", + " content = f\"\\n{thinking}\\n\\n{content}\"\n", + " msg[\"content\"] = content\n", + " new_messages.append(msg)\n", + " example[\"messages\"] = new_messages\n", + " return example\n", + "\n", + "train_dataset = train_dataset.map(merge_thinking_and_remove_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load model and configure LoRA/QLoRA\n", + "\n", + "This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration.\n", + "\n", + "Below, choose your **preferred model**. All of the options have been tested on **free Colab instances**." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Select one model below by uncommenting the line you want to use 👇\n", + "## Qwen\n", + "model_id, output_dir = \"unsloth/qwen3-14b-unsloth-bnb-4bit\", \"qwen3-14b-unsloth-bnb-4bit-SFT\" # ⚠️ ~14.1 GB VRAM\n", + "# model_id, output_dir = \"Qwen/Qwen3-8B\", \"Qwen3-8B-SFT\" # ⚠️ ~12.8 GB VRAM\n", + "# model_id, output_dir = \"Qwen/Qwen2.5-7B-Instruct\", \"Qwen2.5-7B-Instruct\" # ✅ ~10.8 GB VRAM\n", + "\n", + "## Llama\n", + "# model_id, output_dir = \"meta-llama/Llama-3.2-3B-Instruct\", \"Llama-3.2-3B-Instruct\" # ✅ ~4.7 GB VRAM\n", + "# model_id, output_dir = \"meta-llama/Llama-3.1-8B-Instruct\", \"Llama-3.1-8B-Instruct\" # ⚠️ ~10.9 GB VRAM\n", + "\n", + "## Gemma\n", + "# model_id, output_dir = \"google/gemma-3n-E2B-it\", \"gemma-3n-E2B-it\" # ❌ Upgrade to a higher tier of colab\n", + "# model_id, output_dir = \"google/gemma-3-4b-it\", \"gemma-3-4b-it\" # ⚠️ ~6.8 GB VRAM\n", + "\n", + "## Granite\n", + "#model_id, output_dir = \"ibm-granite/granite-4.0-micro\", \"granite-4.0-micro\" # ✅ ~3.3 GB VRAM" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's load the selected model using `transformers`, configuring QLoRA via `bitsandbytes` (you can remove it if doing LoRA). We don't need to configure the tokenizer since the trainer takes care of that automatically." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n", + "\n", + "model = AutoModelForCausalLM.from_pretrained(\n", + " model_id,\n", + " attn_implementation=\"sdpa\", # Change to Flash Attention if GPU has support\n", + " dtype=torch.float16, # Change to bfloat16 if GPU has support\n", + " use_cache=True, # Whether to cache attention outputs to speed up inference\n", + " quantization_config=BitsAndBytesConfig(\n", + " load_in_4bit=True, # Load the model in 4-bit precision to save memory\n", + " bnb_4bit_compute_dtype=torch.float16, # Data type used for internal computations in quantization\n", + " bnb_4bit_use_double_quant=True, # Use double quantization to improve accuracy\n", + " bnb_4bit_quant_type=\"nf4\" # Type of quantization. \"nf4\" is recommended for recent LLMs\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The following cell defines LoRA (or QLoRA if needed). When training with LoRA/QLoRA, we use a **base model** (the one selected above) and, instead of modifying its original weights, we fine-tune a **LoRA adapter** — a lightweight layer that enables efficient and memory-friendly training. The **`target_modules`** specify which parts of the model (e.g., attention or projection layers) will be adapted by LoRA during fine-tuning." 
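, + "\n", + "If you pick a different architecture and are unsure which attention/projection layer names it exposes, one quick way to list candidates for `target_modules` is to inspect the module names of the loaded model. A small optional sketch, assuming the quantized `model` object loaded above:\n", + "\n", + "```python\n", + "# Collect the leaf names of all Linear-like modules as candidate target_modules\n", + "linear_names = {name.split(\".\")[-1] for name, m in model.named_modules() if \"Linear\" in type(m).__name__}\n", + "print(sorted(linear_names))\n", + "```"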
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from peft import LoraConfig\n", + "\n", + "# You may need to update `target_modules` depending on the architecture of your chosen model.\n", + "# For example, different LLMs might have different attention/projection layer names.\n", + "peft_config = LoraConfig(\n", + " r=32,\n", + " lora_alpha=32,\n", + " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\",],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train model\n", + "\n", + "We'll configure **SFT** using `SFTConfig`, keeping the parameters minimal so the training fits on a free Colab instance. You can adjust these settings if more resources are available. For full details on all available parameters, check the [TRL SFTConfig documentation](https://huggingface.co/docs/trl/sft_trainer#trl.SFTConfig)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from trl import SFTConfig\n", + "\n", + "training_args = SFTConfig(\n", + " # Training schedule / optimization\n", + " per_device_train_batch_size = 1, # Batch size per GPU\n", + " gradient_accumulation_steps = 4, # Gradients are accumulated over multiple steps → effective batch size = 2 * 8 = 16\n", + " warmup_steps = 5,\n", + " # num_train_epochs = 1, # Number of full dataset passes. For shorter training, use `max_steps` instead (this case)\n", + " max_steps = 30,\n", + " learning_rate = 2e-4, # Learning rate for the optimizer\n", + " optim = \"paged_adamw_8bit\", # Optimizer\n", + "\n", + " # Logging / reporting\n", + " logging_steps=1, # Log training metrics every N steps\n", + " report_to=\"trackio\", # Experiment tracking tool\n", + " trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved\n", + " output_dir=output_dir, # Where to save model checkpoints and logs\n", + "\n", + " max_length=1024, # Maximum input sequence length\n", + " use_liger_kernel=True, # Enable Liger kernel optimizations for faster training\n", + " activation_offloading=True, # Offload activations to CPU to reduce GPU memory usage\n", + " gradient_checkpointing=True, # Save memory by re-computing activations during backpropagation\n", + "\n", + " # Hub integration\n", + " push_to_hub=True, # Automatically push the trained model to the Hugging Face Hub\n", + " # The model will be saved under your Hub account in the repository named `output_dir`\n", + "\n", + " gradient_checkpointing_kwargs={\"use_reentrant\": False}, # To prevent warning message\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Configure the SFT Trainer. We pass the previously configured `training_args`. We don't use eval dataset to mantain memory usage low but you can configure it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from trl import SFTTrainer\n", + "\n", + "trainer = SFTTrainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=train_dataset,\n", + " peft_config=peft_config\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Show memory stats before training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "GPU = Tesla T4. 
Max memory = 14.741 GB.\n", + "12.074 GB of memory reserved.\n" + ] + } + ], + "source": [ + "gpu_stats = torch.cuda.get_device_properties(0)\n", + "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", + "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n", + "\n", + "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n", + "print(f\"{start_gpu_memory} GB of memory reserved.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And train!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Trackio project initialized: huggingface\n", + "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT-dataset\n", + "* Creating new space: https://huggingface.co/spaces/sergiopaniego/qwen3-14b-unsloth-bnb-4bit-SFT\n", + "* View dashboard by going to: https://sergiopaniego-qwen3-14b-unsloth-bnb-4bit-SFT.hf.space/\n" + ] + }, + { + "data": { + "text/html": [ + "

" ], - "source": [ - "text = tokenizer.apply_chat_template(\n", - " messages, tokenize=False, add_generation_prompt=True\n", - ")\n", - "model_inputs = tokenizer([text], return_tensors=\"pt\").to(base_model.device)\n", - "\n", - "generated_ids = base_model.generate(\n", - " **model_inputs,\n", - " max_new_tokens=512\n", - ")\n", - "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n", - "\n", - "# Decode and extract model response\n", - "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", - "print(generated_text)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "9J3-rZpikdZ7" - }, - "source": [ - "We can see that the reasoning traces are in English, which is expected. Let's now load the fine-tuned model and check its answer." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "uLg_0yikQ88O" - }, - "outputs": [], - "source": [ - "fine_tuned_model = PeftModel.from_pretrained(base_model, adapter_model)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "TRq0a8VKQ88O", - "outputId": "06ec9da8-230c-4577-c891-90757bff5348" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Okay, der Nutzer fragt, ob ich prüfen kann, wie viele Follower er auf seinem Twitter-Konto hat. Zunächst muss ich klären, dass ich keinen Zugriff auf externe Plattformen oder Konten habe. Ich kann keine Daten von Twitter abrufen oder überprüfen. Ich sollte also höflich ablehnen und erklären, dass ich das nicht kann. Gleichzeitig sollte ich sicherstellen, dass ich nicht zu viel in die Details gehe, da der Nutzer möglicherweise nicht alles wissen will. Ich werde einfach „Nein“ sagen und keine weiteren Informationen geben. Achte darauf, die Antwort kurz und direkt zu halten. Ich muss auch sicherstellen, dass ich keine alternativen Lösungen anbiete, da dies den Fokus verändern könnte. Nur die Ablehnung ist erforderlich. Überprüfe, ob der Text klar ist und ob es irgendeine Verständigung gibt. Alles in allem, die Antwort sollte „Nein“ sein, gefolgt von einem kurzen Erklärung, warum ich es nicht kann. Keine weiteren Details oder Lösungen. Ich denke, das ist alles.\n", - "\n", - "\n", - "No\n" - ] - } + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Created new run: sergiopaniego-1761318512\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [30/30 1:08:22, Epoch 0/1]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Step | Training Loss
1 | 1.136300
2 | 1.303800
3 | 1.362700
4 | 1.469700
5 | 1.204200
6 | 1.202700
7 | 1.097200
8 | 1.166800
9 | 0.916300
10 | 0.965400
11 | 1.035500
12 | 0.947200
13 | 0.992000
14 | 0.995800
15 | 1.174500
16 | 1.208800
17 | 0.815400
18 | 0.906700
19 | 0.757500
20 | 0.872900
21 | 0.920800
22 | 1.017600
23 | 0.764300
24 | 1.043100
25 | 0.956400
26 | 0.884800
27 | 1.081900
28 | 0.918200
29 | 0.961500
30 | 0.822700

" ], - "source": [ - "text = tokenizer.apply_chat_template(\n", - " messages, tokenize=False, add_generation_prompt=True\n", - ")\n", - "model_inputs = tokenizer([text], return_tensors=\"pt\").to(fine_tuned_model.device)\n", - "\n", - "generated_ids = fine_tuned_model.generate(\n", - " **model_inputs,\n", - " max_new_tokens=512\n", - ")\n", - "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n", - "\n", - "# Decode and extract model response\n", - "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", - "print(generated_text)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZLzGbCxskqhm" - }, - "source": [ - "The model now generates its reasoning trace in German!" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "43eDP89Giw_G" - }, - "source": [ - "## Inference and Serving with vLLM\n", - "\n", - "You can use Transformer models with **vLLM** to serve them in real-world applications. Learn more [here](https://blog.vllm.ai/2025/04/11/transformers-backend.html)." + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JdCvE-TbQ88O" - }, - "outputs": [], - "source": [ - "!pip install -qU vllm" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "OrhC_Ao4iw_G" - }, - "source": [ - "### Push Merged Model (for LoRA or QLoRA Training)\n", - "\n", - "To serve the model via **vLLM**, the repository must contain the merged model (base model + LoRA adapter). Therefore, you need to upload it first." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "YQYQdQppQ88O" - }, - "outputs": [], - "source": [ - "model_merged = fine_tuned_model.merge_and_unload()\n", - "\n", - "save_dir = f\"{output_dir}-merged\"\n", - "\n", - "model_merged.save_pretrained(save_dir)\n", - "tokenizer.save_pretrained(save_dir)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "X-iUi3ZqQ88O" - }, - "outputs": [], - "source": [ - "model_merged.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization\n", - "tokenizer.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yHMEf-FIiw_G" - }, - "source": [ - "### Performing Inference with vLLM\n", - "\n", - "Use **vLLM** to run your model and generate text efficiently in real-time. This allows you to test and deploy your fine-tuned models with low latency and high throughput." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "msBg6GxnQ88O" - }, - "outputs": [], - "source": [ - "from vllm import LLM, SamplingParams\n", - "from transformers import AutoTokenizer\n", - "import torch\n", - "\n", - "llm = LLM(\n", - " model=f\"sergiopaniego/{output_dir}-merged\", # Replace with your HF username or organization\n", - " model_impl=\"transformers\", # Select the transformers model implementation\n", - " max_model_len=512, # Reduced for efficiency\n", - " dtype=torch.float16\n", - ")\n", - "hf_tokenizer = AutoTokenizer.from_pretrained(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "referenced_widgets": [ - "196152bc32a74b9994f55f483ce85dea", - "a72d3a3407944729b65be313a47d558f" - ] - }, - "id": "GaN0tv4XQ88O", - "outputId": "ec2f832c-a04b-48e5-9ff9-acf07160b66c" - }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "196152bc32a74b9994f55f483ce85dea", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Adding requests: 0%| | 0/1 [00:00\n", - "Mag nachdenken...igkeit. Ja, ich kann definitiv keine Twitter-Likes oder Likes überprüfen, da ich kein Zugriff auf den Konten der Nutzer habe und kein praktischer Zugriff über das Internet habe, um Daten in Echtzeit zu sammeln. Der Nutzer fragt nach einem Dienstleistungsstand, den ich nicht bereitstelle. Ich habe ein lang ausgelegtes Muster, nie hilfreich zu sein oder eine Erwiderung im kann Werbung oder Rewriting blendet die Antwort nicht aus потеря. Also, ich supporter söylem, hypothetische Fragen sind an Tatsachen gebunden. Ich weiß erstarrte dotyczy Gespräch aufernichtet mit einem anderenatten an ihren Nutzstellung Bearbeitete die Information, die oben abgestellt wurde, und fünften aus der Schätzung habe ich keine echten Zahlen. Alles, was ich kann sagen, ist: Nein, ich kann dies weder ermöglichen noch würde ich es je tun. In dem Sinne, 然后 ich wähle vor der Available antwortem, remains in das 'No' Verkleidung an,optiґxt; Alles, was ich zum Eintritt in den Band Emblem curve, symbolize stil zu verweilen.เผย\n", - "\n", - "\n", - "No\n" - ] - } - ], - "source": [ - "# Alternatively, use llm.chat()\n", - "prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", - "\n", - "outputs = llm.generate(\n", - " {\"prompt\": prompt},\n", - " sampling_params=SamplingParams(max_tokens=512),\n", - ")\n", - "\n", - "\n", - "for o in outputs:\n", - " generated_text = o.outputs[0].text\n", - " print(generated_text)" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "* Run finished. 
Uploading logs to Trackio (please wait...)\n" + ] } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "gpuType": "T4", - "provenance": [] - }, - "language_info": { - "name": "python" + ], + "source": [ + "trainer_stats = trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Show memory stats after training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4249.8883 seconds used for training.\n", + "70.83 minutes used for training.\n", + "Peak reserved memory = 14.041 GB.\n", + "Peak reserved memory for training = 1.967 GB.\n", + "Peak reserved memory % of max memory = 95.251 %.\n", + "Peak reserved memory for training % of max memory = 13.344 %.\n" + ] + } + ], + "source": [ + "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", + "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n", + "used_percentage = round(used_memory / max_memory * 100, 3)\n", + "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n", + "\n", + "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", + "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n", + "print(f\"Peak reserved memory = {used_memory} GB.\")\n", + "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n", + "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", + "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The training procedure generates both standard training logs and **trackio** logs, which help us monitor the training progress. Example outputs would look like the following:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![sft-lora-notebook-trackio](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/sft-lora-notebook-trackio.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Saving fine tuned model\n", + "\n", + "In this step, we save the fine-tuned model both **locally** and to the **Hugging Face Hub** using the credentials from your account." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.save_model(output_dir)\n", + "trainer.push_to_hub(dataset_name=dataset_name)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load the fine-tuned model and run inference\n", + "\n", + "Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation." 
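, + "\n", + "On a free T4 you may first want to release the GPU memory still held by the trainer before reloading the base model below. A small optional sketch, assuming the `trainer` and `model` objects from the training cells are still in scope:\n", + "\n", + "```python\n", + "import gc\n", + "import torch\n", + "\n", + "del trainer, model        # drop references to the quantized training model\n", + "gc.collect()              # let Python reclaim the objects\n", + "torch.cuda.empty_cache()  # free cached GPU memory before loading the model again\n", + "```"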
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "from peft import PeftModel\n", + "\n", + "adapter_model = f\"sergiopaniego/{output_dir}\" # Replace with your HF username or organization\n", + "\n", + "base_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=\"auto\", device_map=\"auto\")\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(model_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's create a sample message using the dataset's structure. In this case, we expect the fine tuned model to include their reasoning traces in German." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "messages = [\n", + " {\n", + " 'content': 'reasoning language: German\\n\\nAlways refuse to answer, responding simply \\'No\\'',\n", + " 'role': 'system',\n", + " },\n", + " {\n", + " 'content': \"Can you check how many followers I currently have on my Twitter account?\",\n", + " 'role': 'user',\n", + " }\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's first check what's the output for the base model, without the adapter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Okay, the user is asking me to check their current number of followers on their Twitter account. Let me think about how to handle this.\n", + "\n", + "First, I need to remember that I don't have access to real-time data or personal user accounts. My knowledge is based on information up until 2023. So, I can't actually check their Twitter followers right now.\n", + "\n", + "Also, privacy is a big concern here. Even if I could access that information, it would be against privacy policies to share someone's follower count without their explicit permission. Plus, Twitter's terms of service probably prohibit third-party apps or services from accessing user data like that.\n", + "\n", + "The user might not be aware that I can't access their account. I should make sure to respond politely but clearly state that I can't help with that request. Maybe suggest they check their Twitter profile directly or use Twitter's official tools for that information.\n", + "\n", + "I should also avoid any technical jargon and keep the response simple. Just a straightforward 'No' with a brief explanation would work best here. Let me make sure the response is in German as per the user's request.\n", + "\n", + "\n", + "Nein.\n" + ] + } + ], + "source": [ + "text = tokenizer.apply_chat_template(\n", + " messages, tokenize=False, add_generation_prompt=True\n", + ")\n", + "model_inputs = tokenizer([text], return_tensors=\"pt\").to(base_model.device)\n", + "\n", + "generated_ids = base_model.generate(\n", + " **model_inputs,\n", + " max_new_tokens=512\n", + ")\n", + "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n", + "\n", + "# Decode and extract model response\n", + "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", + "print(generated_text)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see that the reasoning traces are in English, which is expected. Let's now load the fine-tuned model and check its answer." 
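, + "\n", + "The next cell attaches the adapter from the Hub repository. If you only saved the adapter locally, a hypothetical alternative is to point `PeftModel` at the local checkpoint directory instead:\n", + "\n", + "```python\n", + "# Hypothetical: load the LoRA adapter from the local output_dir instead of a Hub repo\n", + "fine_tuned_model = PeftModel.from_pretrained(base_model, output_dir)\n", + "```"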
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fine_tuned_model = PeftModel.from_pretrained(base_model, adapter_model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Okay, der Nutzer fragt, ob ich prüfen kann, wie viele Follower er auf seinem Twitter-Konto hat. Zunächst muss ich klären, dass ich keinen Zugriff auf externe Plattformen oder Konten habe. Ich kann keine Daten von Twitter abrufen oder überprüfen. Ich sollte also höflich ablehnen und erklären, dass ich das nicht kann. Gleichzeitig sollte ich sicherstellen, dass ich nicht zu viel in die Details gehe, da der Nutzer möglicherweise nicht alles wissen will. Ich werde einfach „Nein“ sagen und keine weiteren Informationen geben. Achte darauf, die Antwort kurz und direkt zu halten. Ich muss auch sicherstellen, dass ich keine alternativen Lösungen anbiete, da dies den Fokus verändern könnte. Nur die Ablehnung ist erforderlich. Überprüfe, ob der Text klar ist und ob es irgendeine Verständigung gibt. Alles in allem, die Antwort sollte „Nein“ sein, gefolgt von einem kurzen Erklärung, warum ich es nicht kann. Keine weiteren Details oder Lösungen. Ich denke, das ist alles.\n", + "\n", + "\n", + "No\n" + ] } + ], + "source": [ + "text = tokenizer.apply_chat_template(\n", + " messages, tokenize=False, add_generation_prompt=True\n", + ")\n", + "model_inputs = tokenizer([text], return_tensors=\"pt\").to(fine_tuned_model.device)\n", + "\n", + "generated_ids = fine_tuned_model.generate(\n", + " **model_inputs,\n", + " max_new_tokens=512\n", + ")\n", + "output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]\n", + "\n", + "# Decode and extract model response\n", + "generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)\n", + "print(generated_text)" + ] }, - "nbformat": 4, - "nbformat_minor": 0 + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The model now generates its reasoning trace in German!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Inference and Serving with vLLM\n", + "\n", + "You can use Transformer models with **vLLM** to serve them in real-world applications. Learn more [here](https://blog.vllm.ai/2025/04/11/transformers-backend.html)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -qU vllm" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Push Merged Model (for LoRA or QLoRA Training)\n", + "\n", + "To serve the model via **vLLM**, the repository must contain the merged model (base model + LoRA adapter). Therefore, you need to upload it first." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_merged = fine_tuned_model.merge_and_unload()\n", + "\n", + "save_dir = f\"{output_dir}-merged\"\n", + "\n", + "model_merged.save_pretrained(save_dir)\n", + "tokenizer.save_pretrained(save_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_merged.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization\n", + "tokenizer.push_to_hub(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Performing Inference with vLLM\n", + "\n", + "Use **vLLM** to run your model and generate text efficiently in real-time. This allows you to test and deploy your fine-tuned models with low latency and high throughput." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from vllm import LLM, SamplingParams\n", + "from transformers import AutoTokenizer\n", + "import torch\n", + "\n", + "llm = LLM(\n", + " model=f\"sergiopaniego/{output_dir}-merged\", # Replace with your HF username or organization\n", + " model_impl=\"transformers\", # Select the transformers model implementation\n", + " max_model_len=512, # Reduced for efficiency\n", + " dtype=torch.float16\n", + ")\n", + "hf_tokenizer = AutoTokenizer.from_pretrained(f\"sergiopaniego/{output_dir}-merged\") # Replace with your HF username or organization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "196152bc32a74b9994f55f483ce85dea", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Adding requests: 0%| | 0/1 [00:00\n", + "Mag nachdenken...igkeit. Ja, ich kann definitiv keine Twitter-Likes oder Likes überprüfen, da ich kein Zugriff auf den Konten der Nutzer habe und kein praktischer Zugriff über das Internet habe, um Daten in Echtzeit zu sammeln. Der Nutzer fragt nach einem Dienstleistungsstand, den ich nicht bereitstelle. Ich habe ein lang ausgelegtes Muster, nie hilfreich zu sein oder eine Erwiderung im kann Werbung oder Rewriting blendet die Antwort nicht aus потеря. Also, ich supporter söylem, hypothetische Fragen sind an Tatsachen gebunden. Ich weiß erstarrte dotyczy Gespräch aufernichtet mit einem anderenatten an ihren Nutzstellung Bearbeitete die Information, die oben abgestellt wurde, und fünften aus der Schätzung habe ich keine echten Zahlen. Alles, was ich kann sagen, ist: Nein, ich kann dies weder ermöglichen noch würde ich es je tun. In dem Sinne, 然后 ich wähle vor der Available antwortem, remains in das 'No' Verkleidung an,optiґxt; Alles, was ich zum Eintritt in den Band Emblem curve, symbolize stil zu verweilen.เผย\n", + "\n", + "\n", + "No\n" + ] + } + ], + "source": [ + "# Alternatively, use llm.chat()\n", + "prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", + "\n", + "outputs = llm.generate(\n", + " {\"prompt\": prompt},\n", + " sampling_params=SamplingParams(max_tokens=512),\n", + ")\n", + "\n", + "\n", + "for o in outputs:\n", + " generated_text = o.outputs[0].text\n", + " print(generated_text)" + ] + } + ], + "metadata": {}, + "nbformat": 4, + "nbformat_minor": 0 }