Feature: Port and Optimize Unsloth Notebooks for Kaggle Dual T4 GPUs (T4x2) #212

Open
AAB20 wants to merge 3 commits into unslothai:main from AAB20:main

AAB20 commented Mar 17, 2026

Overview

This pull request adapts the official Unsloth AI notebooks for Kaggle. The originals were designed primarily for Google Colab; they have been refactored to run on Kaggle's platform.

Key Modifications & Enhancements

  • Environment Migration: Transitioned the core setup and library installations from Google Colab to Kaggle, resolving dependency conflicts and adjusting system paths.
  • Hardware Optimization (T4x2): Configured the notebooks for Kaggle's dual T4 GPU instances (T4x2), distributing memory across both GPUs for faster and more stable training.
  • Memory Management: Adjusted settings to fit the VRAM constraints of the dual T4 setup.

Supported Tasks & Models

The updated directory includes working notebooks for the following tasks:

  • Large Language Model (LLM) Fine-tuning
  • Reinforcement Learning (RL) implementations
  • Text, Vision, and Audio Embedding models
  • Text-to-Speech (TTS) model optimizations

Motivation

Porting these notebooks to Kaggle aims to broaden access to AI fine-tuning. It gives developers and researchers a ready-to-use template for Kaggle's free multi-GPU compute, extending the practical reach of the Unsloth ecosystem.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request ports and optimizes a collection of Unsloth AI notebooks for Kaggle's dual T4 GPU environment. The changes enable efficient fine-tuning and inference for various advanced AI models, including LLMs, RL, and multimodal embeddings, making high-performance AI development more accessible to researchers and developers on the Kaggle platform.

Highlights

  • Kaggle T4x2 Optimization: Refactored and optimized official Unsloth AI notebooks to run seamlessly and efficiently on Kaggle's dual T4 GPU instances, maximizing memory distribution and compute efficiency.
  • Expanded Model Support: Introduced new notebooks covering a wide range of advanced AI tasks, including Large Language Model (LLM) fine-tuning, Reinforcement Learning (RL), and Text, Vision, and Audio Embedding models.
  • Enhanced Accessibility: Aims to democratize access to high-performance AI fine-tuning by providing ready-to-use templates leveraging Kaggle's free multi-GPU compute, expanding the practical utility of the Unsloth ecosystem.
  • Specific Notebook Additions: Added new notebooks for Llama 3.1 GRPO LoRA, conversational CodeGemma, Deepseek OCR 2, EmbeddingGemma, Falcon H1 Alpaca, GLM Flash, GPT-OSS BNB inference, and Gemma 3N multimodal tasks.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

gemini-code-assist bot left a comment

Code Review

This pull request introduces a valuable set of notebooks for various models on Kaggle's T4x2 environment. The optimizations for this specific hardware are a great addition.

My main feedback is regarding the significant amount of duplicated code across the notebooks. Many helper functions (like load_with_retry), configuration blocks, and setup cells are copied in multiple files. This will make future maintenance difficult, as any change will need to be manually propagated to all notebooks. I strongly recommend refactoring this common code into a shared Python utility script that can be imported into each notebook. This would greatly improve the maintainability and readability of the project.

I've left a few specific comments on instances of code duplication and other potential improvements for correctness and maintainability.

"execution_count": null,
"metadata": {},
"outputs": [],
"source": "from datasets import load_dataset, load_from_disk\nimport os, time\n\nDATASET_CACHE = \"/kaggle/working/_dataset_cache\"\n\ndef load_with_retry(path, split=\"train\", retries=5, wait=15):\n if os.path.exists(DATASET_CACHE):\n print(f\"Loading from cache: {DATASET_CACHE}\")\n return load_from_disk(DATASET_CACHE)\n for i in range(retries):\n try:\n print(f\"Attempt {i+1}/{retries}: {path}\")\n ds = load_dataset(path, split=split)\n ds.save_to_disk(DATASET_CACHE)\n print(f\"Loaded OK — {len(ds)} rows\")\n return ds\n except Exception as e:\n print(f\" Failed: {e}\")\n if i < retries - 1:\n time.sleep(wait)\n raise RuntimeError(f\"Failed to load {path} after {retries} attempts\")"

medium

This load_with_retry function is duplicated across multiple notebooks. To improve maintainability, consider moving this and other shared utility functions into a common Python script that can be imported by all notebooks.

Additionally, the caching logic is not robust. It uses a fixed DATASET_CACHE path regardless of the dataset path argument. This can lead to cache collisions if different datasets are used across notebooks in the same environment. A better approach would be to create a unique cache directory for each dataset, for example by using a hash of the dataset path.

Example for a more robust caching mechanism:

import hashlib

def get_cache_path(dataset_path):
    hash_object = hashlib.md5(dataset_path.encode())
    return f"/kaggle/working/_dataset_cache_{hash_object.hexdigest()}"

# In load_with_retry:
dataset_cache_path = get_cache_path(path)
if os.path.exists(dataset_cache_path):
    # ... load from disk
...
ds.save_to_disk(dataset_cache_path)
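
A fuller version of the per-dataset cache idea might look like the sketch below. It is not the notebooks' code: the `split`-aware cache key, the SHA-256 digest, the `cache_root` parameter, and the injectable `load_fn`, `load_cached_fn`, and `save_fn` parameters are assumptions added here so the logic can run without the `datasets` library installed.

```python
import hashlib
import os
import time
from typing import Any, Callable


def get_cache_path(dataset_path: str, split: str = "train",
                   root: str = "/kaggle/working") -> str:
    """Return a cache directory that is unique per (dataset, split) pair."""
    # Assumption: the split is part of the key, so "train" and "test"
    # of the same dataset never overwrite each other.
    key = f"{dataset_path}::{split}"
    digest = hashlib.sha256(key.encode("utf-8")).hexdigest()[:16]
    return os.path.join(root, f"_dataset_cache_{digest}")


def load_with_retry(
    path: str,
    split: str = "train",
    retries: int = 5,
    wait: float = 15,
    *,
    load_fn: Callable[..., Any],
    load_cached_fn: Callable[[str], Any],
    save_fn: Callable[[Any, str], None],
    cache_root: str = "/kaggle/working",
) -> Any:
    """Load a dataset, retrying on failure and caching it per dataset."""
    cache_path = get_cache_path(path, split, cache_root)
    if os.path.exists(cache_path):
        print(f"Loading from cache: {cache_path}")
        return load_cached_fn(cache_path)

    last_error = None
    for attempt in range(1, retries + 1):
        try:
            print(f"Attempt {attempt}/{retries}: {path}")
            ds = load_fn(path, split=split)
            save_fn(ds, cache_path)
            return ds
        except Exception as e:  # broad on purpose: network errors vary
            last_error = e
            print(f"  Failed: {e}")
            if attempt < retries:
                time.sleep(wait)
    raise RuntimeError(
        f"Failed to load {path} after {retries} attempts"
    ) from last_error
```

In a notebook this could be called with `load_fn=load_dataset`, `load_cached_fn=load_from_disk`, and `save_fn=lambda ds, p: ds.save_to_disk(p)`, which keeps the behavior of the original cell while giving each dataset its own cache directory.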

"id": "GjBFrttr-y1_"
},
"outputs": [],
"source": "global PRINTED_TIMES\nPRINTED_TIMES = 0\nglobal PRINT_EVERY_STEPS\nPRINT_EVERY_STEPS = 5\n\ndef check_numbers(prompts, completions, answer, **kwargs):\n question = prompts[0][-1][\"content\"]\n responses = [completion[0][\"content\"] for completion in completions]\n\n extracted_responses = [\n guess.group(1)\n if (guess := match_numbers.search(r)) is not None else None \\\n for r in responses\n ]\n\n scores = []\n # Print only every few steps\n global PRINTED_TIMES\n global PRINT_EVERY_STEPS\n if PRINTED_TIMES % PRINT_EVERY_STEPS == 0:\n print('*'*20, f\"Question:\\n{question}\", f\"\\nAnswer:\\n{answer[0]}\", f\"\\nResponse:\\n{responses[0]}\", f\"\\nExtracted:\\n{extracted_responses[0]}\")\n PRINTED_TIMES += 1\n\n for guess, true_answer in zip(extracted_responses, answer):\n if guess is None:\n scores.append(0)\n continue\n # Convert to numbers\n try:\n true_answer = float(true_answer.strip())\n # Remove commas like in 123,456\n guess = float(guess.strip().replace(\",\", \"\"))\n scores.append(1.5 if guess == true_answer else -0.5)\n except:\n scores.append(0)\n continue\n return scores"

medium

The use of global variables (PRINTED_TIMES, PRINT_EVERY_STEPS) makes this function stateful and harder to test and reuse. This can lead to unexpected behavior if the function is called multiple times or from different places.

A better approach would be to manage this state within a class or by using a closure. For example, you could wrap the reward function in a class:

class RewardLogger:
    def __init__(self, print_every_steps=5):
        self.printed_times = 0
        self.print_every_steps = print_every_steps

    def check_numbers(self, prompts, completions, answer, **kwargs):
        # ... (rest of the function logic)
        if self.printed_times % self.print_every_steps == 0:
            # ... print statement
        self.printed_times += 1
        # ...
        return scores

reward_logger = RewardLogger()
# in GRPOTrainer:
# reward_funcs = [..., reward_logger.check_numbers]

This encapsulates the state and makes the code cleaner and more maintainable.
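
The closure alternative mentioned above can be sketched as follows. The scoring logic follows the notebook cell, but `make_check_numbers` is a hypothetical factory name, and the `answer_pattern` regex stands in for the notebook's `match_numbers`, which is defined elsewhere and not shown in this diff.

```python
import re
from typing import Callable, List


def make_check_numbers(match_numbers: "re.Pattern[str]",
                       print_every_steps: int = 5) -> Callable[..., List[float]]:
    """Build a reward function whose print counter lives in the closure."""
    calls = 0  # replaces the module-level PRINTED_TIMES global

    def check_numbers(prompts, completions, answer, **kwargs):
        nonlocal calls
        question = prompts[0][-1]["content"]
        responses = [completion[0]["content"] for completion in completions]
        extracted = [
            m.group(1) if (m := match_numbers.search(r)) is not None else None
            for r in responses
        ]

        # Print only every few calls, as in the original cell.
        if calls % print_every_steps == 0:
            print("*" * 20, f"Question:\n{question}", f"\nAnswer:\n{answer[0]}",
                  f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted[0]}")
        calls += 1

        scores = []
        for guess, true_answer in zip(extracted, answer):
            if guess is None:
                scores.append(0.0)
                continue
            try:
                truth = float(true_answer.strip())
                # Remove thousands separators such as 123,456.
                value = float(guess.strip().replace(",", ""))
            except (ValueError, AttributeError):
                scores.append(0.0)
                continue
            scores.append(1.5 if value == truth else -0.5)
        return scores

    return check_numbers


# Hypothetical pattern, for illustration only.
answer_pattern = re.compile(r"answer:\s*(-?[\d.,]+)")
check_numbers = make_check_numbers(answer_pattern, print_every_steps=5)
```

Each call to `make_check_numbers` returns an independent function with its own counter, so two trainers in the same notebook session do not share print state.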

"execution_count": null,
"metadata": {},
"outputs": [],
"source": "import gc, torch\n\n# ── Patch trainer before training ─────────────────────────\ntry:\n trainer.args.per_device_train_batch_size = 1\n trainer.args.gradient_accumulation_steps = 32\n trainer.args.gradient_checkpointing = False\n trainer.args.fp16 = True\n trainer.args.bf16 = False\n trainer.args.dataloader_pin_memory = True\n trainer.args.dataloader_num_workers = 2\n trainer.args.ddp_find_unused_parameters = False\n trainer.args.max_grad_norm = 0.3\n eff = trainer.args.per_device_train_batch_size * trainer.args.gradient_accumulation_steps\n print(f\"TrainingArguments patched — effective batch: {eff}\")\nexcept Exception as e:\n print(f\"TrainingArguments skipped: {e}\")\n\ntry:\n _m = trainer.model\n while hasattr(_m, \"module\"):\n _m = _m.module\n if hasattr(_m, \"enable_input_require_grads\"):\n _m.enable_input_require_grads()\n print(\"enable_input_require_grads OK\")\nexcept Exception as e:\n print(f\"enable_input_require_grads skipped: {e}\")\n\ngc.collect()\ntorch.cuda.empty_cache()\nfor i in range(torch.cuda.device_count()):\n print(f\" GPU {i} free: {round(torch.cuda.mem_get_info(i)[0]/1024**3,2)} GiB\")\nprint(\"Ready to train\")"

medium

Patching trainer.args after the GRPOTrainer has been initialized can be confusing and error-prone. The arguments are used during the trainer's __init__. Modifying them afterwards might not have the intended effect for all parameters.

It would be clearer and safer to configure all training arguments directly in the GRPOConfig object before passing it to the trainer. This makes the configuration explicit and centralized.

For example, move the patched arguments to the GRPOConfig in the preceding cell:

training_args = GRPOConfig(
    learning_rate = 5e-6,
    # ... other args
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 32,
    gradient_checkpointing = False,
    fp16 = True,
    bf16 = False,
    dataloader_pin_memory = True,
    dataloader_num_workers = 2,
    ddp_find_unused_parameters = False,
    max_grad_norm = 0.3,
    # ...
)
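
Why late patching can silently fail is easier to see in a toy model. The classes below are hypothetical stand-ins, not TRL's `GRPOConfig` or `GRPOTrainer`; they only illustrate the general hazard of a constructor that derives state from its arguments once.

```python
from dataclasses import dataclass


@dataclass
class ToyConfig:
    """Stand-in for a training config (not the real GRPOConfig)."""
    per_device_train_batch_size: int = 8
    gradient_accumulation_steps: int = 1


class ToyTrainer:
    """Stand-in trainer that derives state from its args in __init__."""

    def __init__(self, args: ToyConfig):
        self.args = args
        # Computed once, at construction time.
        self.effective_batch = (
            args.per_device_train_batch_size * args.gradient_accumulation_steps
        )


# Patching after construction: the derived value is already fixed.
late = ToyTrainer(ToyConfig())
late.args.per_device_train_batch_size = 1
late.args.gradient_accumulation_steps = 32
print(late.effective_batch)   # 8, the patch had no effect here

# Configuring up front: the trainer sees the intended values.
early = ToyTrainer(ToyConfig(per_device_train_batch_size=1,
                             gradient_accumulation_steps=32))
print(early.effective_batch)  # 32
```

Whether a given argument of the real trainer behaves like `effective_batch` here depends on the library version, which is one more reason to set everything in the config object before constructing the trainer.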

"execution_count": null,
"metadata": {},
"outputs": [],
"source": "\nimport os, re\nif \"COLAB_\" not in \"\".join(os.environ.keys()):\n !pip install -q unsloth # Do this in local & cloud setups\nelse:\n import torch; v = re.match(r'[\\d]{1,}\\.[\\d]{1,}', str(torch.__version__)).group(0)\n xformers = 'xformers==' + {'2.10':'0.0.34','2.9':'0.0.33.post1','2.8':'0.0.32.post2'}.get(v, \"0.0.34\")\n !pip install -q sentencepiece protobuf \"datasets==4.3.0\" \"huggingface_hub>=0.34.0\" hf_transfer\n !pip install -q --no-deps unsloth_zoo bitsandbytes accelerate {xformers} peft trl triton unsloth\n!pip install -q transformers==4.56.2\n!pip install -q --no-deps trl==0.22.2"

medium

The regex-based parsing of the torch version calls .group(0) on the result of re.match without checking for None, so an unexpected version string raises an AttributeError. For standard version strings, including development builds (e.g., 2.1.0.dev20230310+cu118), splitting on '.' yields the same major.minor value and is more readable.

For example:

import torch
v = '.'.join(torch.__version__.split('.')[:2])
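
The two parsing approaches can be compared on plain strings, without importing torch. This is a minimal sketch: the function names are hypothetical, and the xformers pins are copied from the install cell under review rather than verified independently.

```python
import re


def major_minor_split(version: str) -> str:
    """Take the first two dot-separated fields, as suggested above."""
    return ".".join(version.split(".")[:2])


def major_minor(version: str) -> str:
    """Regex variant with an explicit check instead of a bare .group(0)."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized torch version: {version!r}")
    return f"{match.group(1)}.{match.group(2)}"


# Pins copied from the notebook's install cell.
XFORMERS_PINS = {"2.10": "0.0.34", "2.9": "0.0.33.post1", "2.8": "0.0.32.post2"}


def xformers_requirement(torch_version: str, default: str = "0.0.34") -> str:
    """Map a torch version string to the xformers pin used by the notebook."""
    return "xformers==" + XFORMERS_PINS.get(major_minor(torch_version), default)


print(major_minor_split("2.1.0.dev20230310+cu118"))  # 2.1
print(major_minor("2.1.0.dev20230310+cu118"))        # 2.1
# A two-field version with a local suffix is where the variants differ:
print(major_minor_split("2.1+cu118"))                # 2.1+cu118
print(major_minor("2.1+cu118"))                      # 2.1
print(xformers_requirement("2.9.1+cu126"))           # xformers==0.0.33.post1
```

In the notebook, the input would be `str(torch.__version__)`. The guarded regex keeps the original behavior for normal versions and fails with a clear message otherwise.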

"id": "p31Z-S6FUieB"
},
"outputs": [],
"source": "unsloth_template = \\\n \"{{ bos_token }}\"\\\n \"{{ 'You are a helpful assistant to the user\\n' }}\"\\\n \"{% endif %}\"\\\n \"{% for message in messages %}\"\\\n \"{% if message['role'] == 'user' %}\"\\\n \"{{ '>>> User: ' + message['content'] + '\\n' }}\"\\\n \"{% elif message['role'] == 'assistant' %}\"\\\n \"{{ '>>> Assistant: ' + message['content'] + eos_token + '\\n' }}\"\\\n \"{% endif %}\"\\\n \"{% endfor %}\"\\\n \"{% if add_generation_prompt %}\"\\\n \"{{ '>>> Assistant: ' }}\"\\\n \"{% endif %}\"\nunsloth_eos_token = \"eos_token\"\n\nif False:\n tokenizer = get_chat_template(\n tokenizer,\n chat_template = (unsloth_template, unsloth_eos_token,), # You must provide a template and EOS token\n mapping = {\"role\" : \"from\", \"content\" : \"value\", \"user\" : \"human\", \"assistant\" : \"gpt\"}, # ShareGPT style\n map_eos_token = True, # Maps <|im_end|> to </s> instead\n )"

medium

This Jinja template has a syntax error. The {% endif %} on the third line of the template string does not have a corresponding {% if ... %} block. This will raise a TemplateSyntaxError if this code block is executed. Although it's currently under an if False block, it's best to fix it for future use. It seems to be a copy-paste error and should likely be removed.
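
The stray tag can be confirmed with a rough standard-library check that pairs `if`/`endif` and `for`/`endfor` tags. This is only a sketch: `unbalanced_tags` is a hypothetical helper and not a Jinja parser. With jinja2 installed, `jinja2.Environment().parse(template)` is the authoritative check.

```python
import re

# Matches the opening of {% if %}, {% for %}, {% endif %}, {% endfor %} tags.
TAG = re.compile(r"\{%-?\s*(if|for|endif|endfor)\b")
CLOSERS = {"endif": "if", "endfor": "for"}


def unbalanced_tags(template: str) -> list:
    """Return a list of block-tag problems; empty if the tags pair up."""
    stack, problems = [], []
    for m in TAG.finditer(template):
        tag = m.group(1)
        if tag in CLOSERS:
            if stack and stack[-1] == CLOSERS[tag]:
                stack.pop()
            else:
                problems.append(f"unexpected {tag} at offset {m.start()}")
        else:
            stack.append(tag)
    problems.extend(f"unclosed {tag}" for tag in stack)
    return problems


# The template pieces from the notebook cell, in order.
template_lines = [
    "{{ bos_token }}",
    "{{ 'You are a helpful assistant to the user\n' }}",
    "{% endif %}",  # stray tag flagged in the review
    "{% for message in messages %}",
    "{% if message['role'] == 'user' %}",
    "{{ '>>> User: ' + message['content'] + '\n' }}",
    "{% elif message['role'] == 'assistant' %}",
    "{{ '>>> Assistant: ' + message['content'] + eos_token + '\n' }}",
    "{% endif %}",
    "{% endfor %}",
    "{% if add_generation_prompt %}",
    "{{ '>>> Assistant: ' }}",
    "{% endif %}",
]

broken = "".join(template_lines)
# Proposed fix: drop the third piece and keep everything else unchanged.
fixed = "".join(line for i, line in enumerate(template_lines) if i != 2)

print(unbalanced_tags(broken))  # one "unexpected endif" entry
print(unbalanced_tags(fixed))   # []
```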

"id": "E2WR-p20LcG_"
},
"outputs": [],
"source": "# @title Create datacollator\n\nimport torch\nimport math\nfrom dataclasses import dataclass\nfrom typing import Dict, List, Any, Tuple\nfrom PIL import Image, ImageOps\nfrom torch.nn.utils.rnn import pad_sequence\nimport io\n\nfrom deepseek_ocr2.modeling_deepseekocr2 import (\n format_messages,\n text_encode,\n BasicImageTransform,\n dynamic_preprocess,\n)\n\n@dataclass\nclass DeepSeekOCR2DataCollator:\n \"\"\"\n Args:\n tokenizer: Tokenizer\n model: Model\n image_size: Size for image patches (default: 768)\n base_size: Size for global view (default: 1024)\n crop_mode: Whether to use dynamic cropping for large images\n train_on_responses_only: If True, only train on assistant responses (mask user prompts)\n \"\"\"\n tokenizer: Any\n model: Any\n image_size: int = 768\n base_size: int = 1024\n crop_mode: bool = True\n image_token_id: int = 128815\n train_on_responses_only: bool = True\n\n def __init__(\n self,\n tokenizer,\n model,\n image_size: int = 768,\n base_size: int = 1024,\n crop_mode: bool = True,\n train_on_responses_only: bool = True,\n ):\n self.tokenizer = tokenizer\n self.model = model\n self.image_size = image_size\n self.base_size = base_size\n self.crop_mode = crop_mode\n self.image_token_id = 128815\n self.dtype = model.dtype # Get dtype from model\n self.train_on_responses_only = train_on_responses_only\n\n self.image_transform = BasicImageTransform(\n mean = (0.5, 0.5, 0.5),\n std = (0.5, 0.5, 0.5),\n normalize = True\n )\n self.patch_size = 16\n self.downsample_ratio = 4\n\n # Get BOS token ID from tokenizer\n if hasattr(tokenizer, 'bos_token_id') and tokenizer.bos_token_id is not None:\n self.bos_id = tokenizer.bos_token_id\n else:\n self.bos_id = 0\n print(f\"Warning: tokenizer has no bos_token_id, using default: {self.bos_id}\")\n\n def deserialize_image(self, image_data) -> Image.Image:\n \"\"\"Convert image data (bytes dict or PIL Image) to PIL Image in RGB mode\"\"\"\n if isinstance(image_data, Image.Image):\n return 
image_data.convert(\"RGB\")\n elif isinstance(image_data, dict) and 'bytes' in image_data:\n image_bytes = image_data['bytes']\n image = Image.open(io.BytesIO(image_bytes))\n return image.convert(\"RGB\")\n else:\n raise ValueError(f\"Unsupported image format: {type(image_data)}\")\n\n def calculate_image_token_count(self, image: Image.Image, crop_ratio: Tuple[int, int]) -> int:\n \"\"\"Calculate the number of tokens this image will generate\"\"\"\n num_queries = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio)\n num_queries_base = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)\n\n width_crop_num, height_crop_num = crop_ratio\n\n if self.crop_mode:\n img_tokens = num_queries_base * num_queries_base + 1\n if width_crop_num > 1 or height_crop_num > 1:\n img_tokens += (num_queries * width_crop_num) * (num_queries * height_crop_num)\n else:\n img_tokens = num_queries * num_queries + 1\n\n return img_tokens\n\n def process_image(self, image: Image.Image) -> Tuple[List, List, List, List, Tuple[int, int]]:\n \"\"\"\n Process a single image based on crop_mode and size thresholds\n\n Returns:\n Tuple of (images_list, images_crop_list, images_spatial_crop, tokenized_image, crop_ratio)\n \"\"\"\n images_list = []\n images_crop_list = []\n images_spatial_crop = []\n\n if self.crop_mode:\n # Determine crop ratio based on image size\n if image.size[0] <= 768 and image.size[1] <= 768:\n crop_ratio = (1, 1)\n images_crop_raw = []\n else:\n images_crop_raw, crop_ratio = dynamic_preprocess(\n image, min_num = 2, max_num = 6,\n image_size = self.image_size, use_thumbnail = False\n )\n\n # Process global view with padding\n global_view = ImageOps.pad(\n image, (self.base_size, self.base_size),\n color = tuple(int(x * 255) for x in self.image_transform.mean)\n )\n images_list.append(self.image_transform(global_view).to(self.dtype))\n\n width_crop_num, height_crop_num = crop_ratio\n images_spatial_crop.append([width_crop_num, 
height_crop_num])\n\n # Process local views (crops) if applicable\n if width_crop_num > 1 or height_crop_num > 1:\n for crop_img in images_crop_raw:\n images_crop_list.append(\n self.image_transform(crop_img).to(self.dtype)\n )\n\n # Calculate image tokens\n num_queries = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio)\n num_queries_base = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)\n\n tokenized_image = ([self.image_token_id] * num_queries_base) * num_queries_base\n tokenized_image += [self.image_token_id]\n\n if width_crop_num > 1 or height_crop_num > 1:\n tokenized_image += ([self.image_token_id] * (num_queries * width_crop_num)) * (\n num_queries * height_crop_num)\n\n else: # crop_mode = False\n crop_ratio = (1, 1)\n images_spatial_crop.append([1, 1])\n\n # For smaller base sizes, resize; for larger, pad\n if self.base_size <= 768:\n resized_image = image.resize((self.base_size, self.base_size), Image.LANCZOS)\n images_list.append(self.image_transform(resized_image).to(self.dtype))\n else:\n global_view = ImageOps.pad(\n image, (self.base_size, self.base_size),\n color = tuple(int(x * 255) for x in self.image_transform.mean)\n )\n images_list.append(self.image_transform(global_view).to(self.dtype))\n\n num_queries = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)\n tokenized_image = ([self.image_token_id] * num_queries) * num_queries\n tokenized_image += [self.image_token_id]\n\n return images_list, images_crop_list, images_spatial_crop, tokenized_image, crop_ratio\n\n def process_single_sample(self, messages: List[Dict]) -> Dict[str, Any]:\n \"\"\"\n Process a single conversation into model inputs.\n \"\"\"\n\n # --- 1. 
Setup ---\n images = []\n for message in messages:\n if \"images\" in message and message[\"images\"]:\n for img_data in message[\"images\"]:\n if img_data is not None:\n pil_image = self.deserialize_image(img_data)\n images.append(pil_image)\n\n if not images:\n raise ValueError(\"No images found in sample. Please ensure all samples contain images.\")\n\n tokenized_str = []\n images_seq_mask = []\n images_list, images_crop_list, images_spatial_crop = [], [], []\n\n prompt_token_count = -1 # Index to start training\n assistant_started = False\n image_idx = 0\n\n # Add BOS token at the very beginning\n tokenized_str.append(self.bos_id)\n images_seq_mask.append(False)\n\n for message in messages:\n role = message[\"role\"]\n content = message[\"content\"]\n\n # Check if this is the assistant's turn\n if role == \"<|Assistant|>\":\n if not assistant_started:\n # This is the split point. All tokens added *so far*\n # are part of the prompt.\n prompt_token_count = len(tokenized_str)\n assistant_started = True\n\n # Append the EOS token string to the *end* of assistant content\n content = f\"{content.strip()} {self.tokenizer.eos_token}\"\n\n # Split this message's content by the image token\n text_splits = content.split('<image>')\n\n for i, text_sep in enumerate(text_splits):\n # Tokenize the text part\n tokenized_sep = text_encode(self.tokenizer, text_sep, bos = False, eos = False)\n tokenized_str.extend(tokenized_sep)\n images_seq_mask.extend([False] * len(tokenized_sep))\n\n # If this text is followed by an <image> tag\n if i < len(text_splits) - 1:\n if image_idx >= len(images):\n raise ValueError(\n f\"Data mismatch: Found '<image>' token but no corresponding image.\"\n )\n\n # Process the image\n image = images[image_idx]\n img_list, crop_list, spatial_crop, tok_img, _ = self.process_image(image)\n\n images_list.extend(img_list)\n images_crop_list.extend(crop_list)\n images_spatial_crop.extend(spatial_crop)\n\n # Add image placeholder tokens\n 
tokenized_str.extend(tok_img)\n images_seq_mask.extend([True] * len(tok_img))\n\n image_idx += 1 # Move to the next image\n\n # --- 3. Validation and Final Prep ---\n if image_idx != len(images):\n raise ValueError(\n f\"Data mismatch: Found {len(images)} images but only {image_idx} '<image>' tokens were used.\"\n )\n\n # If we never found an assistant message, we're in a weird state\n # (e.g., user-only prompt). We mask everything.\n if not assistant_started:\n print(\"Warning: No assistant message found in sample. Masking all tokens.\")\n prompt_token_count = len(tokenized_str)\n\n # Prepare image tensors\n images_ori = torch.stack(images_list, dim = 0)\n images_spatial_crop_tensor = torch.tensor(images_spatial_crop, dtype = torch.long)\n\n if images_crop_list:\n images_crop = torch.stack(images_crop_list, dim = 0)\n else:\n images_crop = torch.zeros((1, 3, self.base_size, self.base_size), dtype = self.dtype)\n\n return {\n \"input_ids\": torch.tensor(tokenized_str, dtype = torch.long),\n \"images_seq_mask\": torch.tensor(images_seq_mask, dtype = torch.bool),\n \"images_ori\": images_ori,\n \"images_crop\": images_crop,\n \"images_spatial_crop\": images_spatial_crop_tensor,\n \"prompt_token_count\": prompt_token_count, # This is now accurate\n }\n\n def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:\n \"\"\"Collate batch of samples\"\"\"\n batch_data = []\n\n # Process each sample\n for feature in features:\n try:\n processed = self.process_single_sample(feature['messages'])\n batch_data.append(processed)\n except Exception as e:\n print(f\"Error processing sample: {e}\")\n continue\n\n if not batch_data:\n raise ValueError(\"No valid samples in batch\")\n\n # Extract lists\n input_ids_list = [item['input_ids'] for item in batch_data]\n images_seq_mask_list = [item['images_seq_mask'] for item in batch_data]\n prompt_token_counts = [item['prompt_token_count'] for item in batch_data]\n\n # Pad sequences\n input_ids = 
pad_sequence(input_ids_list, batch_first = True, padding_value = self.tokenizer.pad_token_id)\n images_seq_mask = pad_sequence(images_seq_mask_list, batch_first = True, padding_value = False)\n\n # Create labels\n labels = input_ids.clone()\n\n # Mask padding tokens\n labels[labels == self.tokenizer.pad_token_id] = -100\n\n # Mask image tokens (model shouldn't predict these)\n labels[images_seq_mask] = -100\n\n # Mask user prompt tokens when train_on_responses_only = True (only train on assistant responses)\n if self.train_on_responses_only:\n for idx, prompt_count in enumerate(prompt_token_counts):\n if prompt_count > 0:\n labels[idx, :prompt_count] = -100\n\n # Create attention mask\n attention_mask = (input_ids != self.tokenizer.pad_token_id).long()\n\n # Prepare images batch (list of tuples)\n images_batch = []\n for item in batch_data:\n images_batch.append((item['images_crop'], item['images_ori']))\n\n # Stack spatial crop info\n images_spatial_crop = torch.cat([item['images_spatial_crop'] for item in batch_data], dim = 0)\n\n return {\n \"input_ids\": input_ids,\n \"attention_mask\": attention_mask,\n \"labels\": labels,\n \"images\": images_batch,\n \"images_seq_mask\": images_seq_mask,\n \"images_spatial_crop\": images_spatial_crop,\n }"

medium

Defining a large and complex class like DeepSeekOCR2DataCollator inside a notebook cell harms readability and reusability. This class is also duplicated in other notebooks, which will make maintenance difficult.

Consider moving this class to a separate utils.py file and importing it into the notebooks where it's needed. This will make the notebooks cleaner and easier to follow, and any changes to the data collator will only need to be made in one place.
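
One way the shared-module pattern could work is sketched below. The module name `unsloth_kaggle_utils` and the helper `effective_batch_size` are hypothetical, and the module is written from a string into a temporary directory only to keep the example self-contained. In the repository it would be a committed file that each notebook adds to `sys.path`.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Hypothetical contents of the shared module. In practice this would hold
# load_with_retry, the data collator, and other duplicated helpers.
MODULE_SOURCE = '''
def effective_batch_size(per_device: int, grad_accum: int, num_gpus: int = 1) -> int:
    """Samples consumed per optimizer step across all GPUs."""
    return per_device * grad_accum * num_gpus
'''


def install_shared_module(directory: str, name: str = "unsloth_kaggle_utils"):
    """Write the module into `directory`, make it importable, and import it."""
    path = Path(directory) / f"{name}.py"
    path.write_text(MODULE_SOURCE)
    if directory not in sys.path:
        sys.path.insert(0, directory)
    importlib.invalidate_caches()
    return importlib.import_module(name)


utils = install_shared_module(tempfile.mkdtemp())
print(utils.effective_batch_size(1, 32, num_gpus=2))  # 64
```

With this layout, a fix to a shared helper is made once in the module file, and every notebook picks it up on its next import.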
