Feature: Port and Optimize Unsloth Notebooks for Kaggle Dual T4 GPUs (T4x2)#212
Feature: Port and Optimize Unsloth Notebooks for Kaggle Dual T4 GPUs (T4x2)#212AAB20 wants to merge 3 commits intounslothai:mainfrom
Conversation
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request ports and optimizes a collection of Unsloth AI notebooks for Kaggle's dual T4 GPU environment. The changes enable efficient fine-tuning and inference for various advanced AI models, including LLMs, RL, and multimodal embeddings, making high-performance AI development more accessible to researchers and developers on the Kaggle platform. Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request introduces a valuable set of notebooks for various models on Kaggle's T4x2 environment. The optimizations for this specific hardware are a great addition.
My main feedback is regarding the significant amount of duplicated code across the notebooks. Many helper functions (like load_with_retry), configuration blocks, and setup cells are copied in multiple files. This will make future maintenance difficult, as any change will need to be manually propagated to all notebooks. I strongly recommend refactoring this common code into a shared Python utility script that can be imported into each notebook. This would greatly improve the maintainability and readability of the project.
I've left a few specific comments on instances of code duplication and other potential improvements for correctness and maintainability.
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": "from datasets import load_dataset, load_from_disk\nimport os, time\n\nDATASET_CACHE = \"/kaggle/working/_dataset_cache\"\n\ndef load_with_retry(path, split=\"train\", retries=5, wait=15):\n if os.path.exists(DATASET_CACHE):\n print(f\"Loading from cache: {DATASET_CACHE}\")\n return load_from_disk(DATASET_CACHE)\n for i in range(retries):\n try:\n print(f\"Attempt {i+1}/{retries}: {path}\")\n ds = load_dataset(path, split=split)\n ds.save_to_disk(DATASET_CACHE)\n print(f\"Loaded OK — {len(ds)} rows\")\n return ds\n except Exception as e:\n print(f\" Failed: {e}\")\n if i < retries - 1:\n time.sleep(wait)\n raise RuntimeError(f\"Failed to load {path} after {retries} attempts\")" |
There was a problem hiding this comment.
This load_with_retry function is duplicated across multiple notebooks. To improve maintainability, consider moving this and other shared utility functions into a common Python script that can be imported by all notebooks.
Additionally, the caching logic is not robust. It uses a fixed DATASET_CACHE path regardless of the dataset path argument. This can lead to cache collisions if different datasets are used across notebooks in the same environment. A better approach would be to create a unique cache directory for each dataset, for example by using a hash of the dataset path.
Example for a more robust caching mechanism:
import hashlib
def get_cache_path(dataset_path):
hash_object = hashlib.md5(dataset_path.encode())
return f"/kaggle/working/_dataset_cache_{hash_object.hexdigest()}"
# In load_with_retry:
dataset_cache_path = get_cache_path(path)
if os.path.exists(dataset_cache_path):
# ... load from disk
...
ds.save_to_disk(dataset_cache_path)| "id": "GjBFrttr-y1_" | ||
| }, | ||
| "outputs": [], | ||
| "source": "global PRINTED_TIMES\nPRINTED_TIMES = 0\nglobal PRINT_EVERY_STEPS\nPRINT_EVERY_STEPS = 5\n\ndef check_numbers(prompts, completions, answer, **kwargs):\n question = prompts[0][-1][\"content\"]\n responses = [completion[0][\"content\"] for completion in completions]\n\n extracted_responses = [\n guess.group(1)\n if (guess := match_numbers.search(r)) is not None else None \\\n for r in responses\n ]\n\n scores = []\n # Print only every few steps\n global PRINTED_TIMES\n global PRINT_EVERY_STEPS\n if PRINTED_TIMES % PRINT_EVERY_STEPS == 0:\n print('*'*20, f\"Question:\\n{question}\", f\"\\nAnswer:\\n{answer[0]}\", f\"\\nResponse:\\n{responses[0]}\", f\"\\nExtracted:\\n{extracted_responses[0]}\")\n PRINTED_TIMES += 1\n\n for guess, true_answer in zip(extracted_responses, answer):\n if guess is None:\n scores.append(0)\n continue\n # Convert to numbers\n try:\n true_answer = float(true_answer.strip())\n # Remove commas like in 123,456\n guess = float(guess.strip().replace(\",\", \"\"))\n scores.append(1.5 if guess == true_answer else -0.5)\n except:\n scores.append(0)\n continue\n return scores" |
There was a problem hiding this comment.
The use of global variables (PRINTED_TIMES, PRINT_EVERY_STEPS) makes this function stateful and harder to test and reuse. This can lead to unexpected behavior if the function is called multiple times or from different places.
A better approach would be to manage this state within a class or by using a closure. For example, you could wrap the reward function in a class:
class RewardLogger:
def __init__(self, print_every_steps=5):
self.printed_times = 0
self.print_every_steps = print_every_steps
def check_numbers(self, prompts, completions, answer, **kwargs):
# ... (rest of the function logic)
if self.printed_times % self.print_every_steps == 0:
# ... print statement
self.printed_times += 1
# ...
return scores
reward_logger = RewardLogger()
# in GRPOTrainer:
# reward_funcs = [..., reward_logger.check_numbers]This encapsulates the state and makes the code cleaner and more maintainable.
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": "import gc, torch\n\n# ── Patch trainer before training ─────────────────────────\ntry:\n trainer.args.per_device_train_batch_size = 1\n trainer.args.gradient_accumulation_steps = 32\n trainer.args.gradient_checkpointing = False\n trainer.args.fp16 = True\n trainer.args.bf16 = False\n trainer.args.dataloader_pin_memory = True\n trainer.args.dataloader_num_workers = 2\n trainer.args.ddp_find_unused_parameters = False\n trainer.args.max_grad_norm = 0.3\n eff = trainer.args.per_device_train_batch_size * trainer.args.gradient_accumulation_steps\n print(f\"TrainingArguments patched — effective batch: {eff}\")\nexcept Exception as e:\n print(f\"TrainingArguments skipped: {e}\")\n\ntry:\n _m = trainer.model\n while hasattr(_m, \"module\"):\n _m = _m.module\n if hasattr(_m, \"enable_input_require_grads\"):\n _m.enable_input_require_grads()\n print(\"enable_input_require_grads OK\")\nexcept Exception as e:\n print(f\"enable_input_require_grads skipped: {e}\")\n\ngc.collect()\ntorch.cuda.empty_cache()\nfor i in range(torch.cuda.device_count()):\n print(f\" GPU {i} free: {round(torch.cuda.mem_get_info(i)[0]/1024**3,2)} GiB\")\nprint(\"Ready to train\")" |
There was a problem hiding this comment.
Patching trainer.args after the GRPOTrainer has been initialized can be confusing and error-prone. The arguments are used during the trainer's __init__. Modifying them afterwards might not have the intended effect for all parameters.
It would be clearer and safer to configure all training arguments directly in the GRPOConfig object before passing it to the trainer. This makes the configuration explicit and centralized.
For example, move the patched arguments to the GRPOConfig in the preceding cell:
training_args = GRPOConfig(
learning_rate = 5e-6,
# ... other args
per_device_train_batch_size = 1,
gradient_accumulation_steps = 32,
gradient_checkpointing = False,
fp16 = True,
bf16 = False,
dataloader_pin_memory = True,
dataloader_num_workers = 2,
ddp_find_unused_parameters = False,
max_grad_norm = 0.3,
# ...
)| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": "\nimport os, re\nif \"COLAB_\" not in \"\".join(os.environ.keys()):\n !pip install -q unsloth # Do this in local & cloud setups\nelse:\n import torch; v = re.match(r'[\\d]{1,}\\.[\\d]{1,}', str(torch.__version__)).group(0)\n xformers = 'xformers==' + {'2.10':'0.0.34','2.9':'0.0.33.post1','2.8':'0.0.32.post2'}.get(v, \"0.0.34\")\n !pip install -q sentencepiece protobuf \"datasets==4.3.0\" \"huggingface_hub>=0.34.0\" hf_transfer\n !pip install -q --no-deps unsloth_zoo bitsandbytes accelerate {xformers} peft trl triton unsloth\n!pip install -q transformers==4.56.2\n!pip install -q --no-deps trl==0.22.2" |
There was a problem hiding this comment.
The regex used to parse the torch version is fragile and might fail on development or more complex version strings (e.g., 2.1.0.dev20230310+cu118). A more robust way to get the major.minor version would be to use string splitting, which is also more readable.
For example:
import torch
v = '.'.join(torch.__version__.split('.')[:2])| "id": "p31Z-S6FUieB" | ||
| }, | ||
| "outputs": [], | ||
| "source": "unsloth_template = \\\n \"{{ bos_token }}\"\\\n \"{{ 'You are a helpful assistant to the user\\n' }}\"\\\n \"{% endif %}\"\\\n \"{% for message in messages %}\"\\\n \"{% if message['role'] == 'user' %}\"\\\n \"{{ '>>> User: ' + message['content'] + '\\n' }}\"\\\n \"{% elif message['role'] == 'assistant' %}\"\\\n \"{{ '>>> Assistant: ' + message['content'] + eos_token + '\\n' }}\"\\\n \"{% endif %}\"\\\n \"{% endfor %}\"\\\n \"{% if add_generation_prompt %}\"\\\n \"{{ '>>> Assistant: ' }}\"\\\n \"{% endif %}\"\nunsloth_eos_token = \"eos_token\"\n\nif False:\n tokenizer = get_chat_template(\n tokenizer,\n chat_template = (unsloth_template, unsloth_eos_token,), # You must provide a template and EOS token\n mapping = {\"role\" : \"from\", \"content\" : \"value\", \"user\" : \"human\", \"assistant\" : \"gpt\"}, # ShareGPT style\n map_eos_token = True, # Maps <|im_end|> to </s> instead\n )" |
There was a problem hiding this comment.
This Jinja template has a syntax error. The {% endif %} on the third line of the template string does not have a corresponding {% if ... %} block. This will raise a TemplateSyntaxError if this code block is executed. Although it's currently under an if False block, it's best to fix it for future use. It seems to be a copy-paste error and should likely be removed.
| "id": "E2WR-p20LcG_" | ||
| }, | ||
| "outputs": [], | ||
| "source": "# @title Create datacollator\n\nimport torch\nimport math\nfrom dataclasses import dataclass\nfrom typing import Dict, List, Any, Tuple\nfrom PIL import Image, ImageOps\nfrom torch.nn.utils.rnn import pad_sequence\nimport io\n\nfrom deepseek_ocr2.modeling_deepseekocr2 import (\n format_messages,\n text_encode,\n BasicImageTransform,\n dynamic_preprocess,\n)\n\n@dataclass\nclass DeepSeekOCR2DataCollator:\n \"\"\"\n Args:\n tokenizer: Tokenizer\n model: Model\n image_size: Size for image patches (default: 768)\n base_size: Size for global view (default: 1024)\n crop_mode: Whether to use dynamic cropping for large images\n train_on_responses_only: If True, only train on assistant responses (mask user prompts)\n \"\"\"\n tokenizer: Any\n model: Any\n image_size: int = 768\n base_size: int = 1024\n crop_mode: bool = True\n image_token_id: int = 128815\n train_on_responses_only: bool = True\n\n def __init__(\n self,\n tokenizer,\n model,\n image_size: int = 768,\n base_size: int = 1024,\n crop_mode: bool = True,\n train_on_responses_only: bool = True,\n ):\n self.tokenizer = tokenizer\n self.model = model\n self.image_size = image_size\n self.base_size = base_size\n self.crop_mode = crop_mode\n self.image_token_id = 128815\n self.dtype = model.dtype # Get dtype from model\n self.train_on_responses_only = train_on_responses_only\n\n self.image_transform = BasicImageTransform(\n mean = (0.5, 0.5, 0.5),\n std = (0.5, 0.5, 0.5),\n normalize = True\n )\n self.patch_size = 16\n self.downsample_ratio = 4\n\n # Get BOS token ID from tokenizer\n if hasattr(tokenizer, 'bos_token_id') and tokenizer.bos_token_id is not None:\n self.bos_id = tokenizer.bos_token_id\n else:\n self.bos_id = 0\n print(f\"Warning: tokenizer has no bos_token_id, using default: {self.bos_id}\")\n\n def deserialize_image(self, image_data) -> Image.Image:\n \"\"\"Convert image data (bytes dict or PIL Image) to PIL Image in RGB mode\"\"\"\n if isinstance(image_data, Image.Image):\n return image_data.convert(\"RGB\")\n elif isinstance(image_data, dict) and 'bytes' in image_data:\n image_bytes = image_data['bytes']\n image = Image.open(io.BytesIO(image_bytes))\n return image.convert(\"RGB\")\n else:\n raise ValueError(f\"Unsupported image format: {type(image_data)}\")\n\n def calculate_image_token_count(self, image: Image.Image, crop_ratio: Tuple[int, int]) -> int:\n \"\"\"Calculate the number of tokens this image will generate\"\"\"\n num_queries = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio)\n num_queries_base = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)\n\n width_crop_num, height_crop_num = crop_ratio\n\n if self.crop_mode:\n img_tokens = num_queries_base * num_queries_base + 1\n if width_crop_num > 1 or height_crop_num > 1:\n img_tokens += (num_queries * width_crop_num) * (num_queries * height_crop_num)\n else:\n img_tokens = num_queries * num_queries + 1\n\n return img_tokens\n\n def process_image(self, image: Image.Image) -> Tuple[List, List, List, List, Tuple[int, int]]:\n \"\"\"\n Process a single image based on crop_mode and size thresholds\n\n Returns:\n Tuple of (images_list, images_crop_list, images_spatial_crop, tokenized_image, crop_ratio)\n \"\"\"\n images_list = []\n images_crop_list = []\n images_spatial_crop = []\n\n if self.crop_mode:\n # Determine crop ratio based on image size\n if image.size[0] <= 768 and image.size[1] <= 768:\n crop_ratio = (1, 1)\n images_crop_raw = []\n else:\n images_crop_raw, crop_ratio = dynamic_preprocess(\n image, min_num = 2, max_num = 6,\n image_size = self.image_size, use_thumbnail = False\n )\n\n # Process global view with padding\n global_view = ImageOps.pad(\n image, (self.base_size, self.base_size),\n color = tuple(int(x * 255) for x in self.image_transform.mean)\n )\n images_list.append(self.image_transform(global_view).to(self.dtype))\n\n width_crop_num, height_crop_num = crop_ratio\n images_spatial_crop.append([width_crop_num, height_crop_num])\n\n # Process local views (crops) if applicable\n if width_crop_num > 1 or height_crop_num > 1:\n for crop_img in images_crop_raw:\n images_crop_list.append(\n self.image_transform(crop_img).to(self.dtype)\n )\n\n # Calculate image tokens\n num_queries = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio)\n num_queries_base = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)\n\n tokenized_image = ([self.image_token_id] * num_queries_base) * num_queries_base\n tokenized_image += [self.image_token_id]\n\n if width_crop_num > 1 or height_crop_num > 1:\n tokenized_image += ([self.image_token_id] * (num_queries * width_crop_num)) * (\n num_queries * height_crop_num)\n\n else: # crop_mode = False\n crop_ratio = (1, 1)\n images_spatial_crop.append([1, 1])\n\n # For smaller base sizes, resize; for larger, pad\n if self.base_size <= 768:\n resized_image = image.resize((self.base_size, self.base_size), Image.LANCZOS)\n images_list.append(self.image_transform(resized_image).to(self.dtype))\n else:\n global_view = ImageOps.pad(\n image, (self.base_size, self.base_size),\n color = tuple(int(x * 255) for x in self.image_transform.mean)\n )\n images_list.append(self.image_transform(global_view).to(self.dtype))\n\n num_queries = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)\n tokenized_image = ([self.image_token_id] * num_queries) * num_queries\n tokenized_image += [self.image_token_id]\n\n return images_list, images_crop_list, images_spatial_crop, tokenized_image, crop_ratio\n\n def process_single_sample(self, messages: List[Dict]) -> Dict[str, Any]:\n \"\"\"\n Process a single conversation into model inputs.\n \"\"\"\n\n # --- 1. Setup ---\n images = []\n for message in messages:\n if \"images\" in message and message[\"images\"]:\n for img_data in message[\"images\"]:\n if img_data is not None:\n pil_image = self.deserialize_image(img_data)\n images.append(pil_image)\n\n if not images:\n raise ValueError(\"No images found in sample. Please ensure all samples contain images.\")\n\n tokenized_str = []\n images_seq_mask = []\n images_list, images_crop_list, images_spatial_crop = [], [], []\n\n prompt_token_count = -1 # Index to start training\n assistant_started = False\n image_idx = 0\n\n # Add BOS token at the very beginning\n tokenized_str.append(self.bos_id)\n images_seq_mask.append(False)\n\n for message in messages:\n role = message[\"role\"]\n content = message[\"content\"]\n\n # Check if this is the assistant's turn\n if role == \"<|Assistant|>\":\n if not assistant_started:\n # This is the split point. All tokens added *so far*\n # are part of the prompt.\n prompt_token_count = len(tokenized_str)\n assistant_started = True\n\n # Append the EOS token string to the *end* of assistant content\n content = f\"{content.strip()} {self.tokenizer.eos_token}\"\n\n # Split this message's content by the image token\n text_splits = content.split('<image>')\n\n for i, text_sep in enumerate(text_splits):\n # Tokenize the text part\n tokenized_sep = text_encode(self.tokenizer, text_sep, bos = False, eos = False)\n tokenized_str.extend(tokenized_sep)\n images_seq_mask.extend([False] * len(tokenized_sep))\n\n # If this text is followed by an <image> tag\n if i < len(text_splits) - 1:\n if image_idx >= len(images):\n raise ValueError(\n f\"Data mismatch: Found '<image>' token but no corresponding image.\"\n )\n\n # Process the image\n image = images[image_idx]\n img_list, crop_list, spatial_crop, tok_img, _ = self.process_image(image)\n\n images_list.extend(img_list)\n images_crop_list.extend(crop_list)\n images_spatial_crop.extend(spatial_crop)\n\n # Add image placeholder tokens\n tokenized_str.extend(tok_img)\n images_seq_mask.extend([True] * len(tok_img))\n\n image_idx += 1 # Move to the next image\n\n # --- 3. Validation and Final Prep ---\n if image_idx != len(images):\n raise ValueError(\n f\"Data mismatch: Found {len(images)} images but only {image_idx} '<image>' tokens were used.\"\n )\n\n # If we never found an assistant message, we're in a weird state\n # (e.g., user-only prompt). We mask everything.\n if not assistant_started:\n print(\"Warning: No assistant message found in sample. Masking all tokens.\")\n prompt_token_count = len(tokenized_str)\n\n # Prepare image tensors\n images_ori = torch.stack(images_list, dim = 0)\n images_spatial_crop_tensor = torch.tensor(images_spatial_crop, dtype = torch.long)\n\n if images_crop_list:\n images_crop = torch.stack(images_crop_list, dim = 0)\n else:\n images_crop = torch.zeros((1, 3, self.base_size, self.base_size), dtype = self.dtype)\n\n return {\n \"input_ids\": torch.tensor(tokenized_str, dtype = torch.long),\n \"images_seq_mask\": torch.tensor(images_seq_mask, dtype = torch.bool),\n \"images_ori\": images_ori,\n \"images_crop\": images_crop,\n \"images_spatial_crop\": images_spatial_crop_tensor,\n \"prompt_token_count\": prompt_token_count, # This is now accurate\n }\n\n def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:\n \"\"\"Collate batch of samples\"\"\"\n batch_data = []\n\n # Process each sample\n for feature in features:\n try:\n processed = self.process_single_sample(feature['messages'])\n batch_data.append(processed)\n except Exception as e:\n print(f\"Error processing sample: {e}\")\n continue\n\n if not batch_data:\n raise ValueError(\"No valid samples in batch\")\n\n # Extract lists\n input_ids_list = [item['input_ids'] for item in batch_data]\n images_seq_mask_list = [item['images_seq_mask'] for item in batch_data]\n prompt_token_counts = [item['prompt_token_count'] for item in batch_data]\n\n # Pad sequences\n input_ids = pad_sequence(input_ids_list, batch_first = True, padding_value = self.tokenizer.pad_token_id)\n images_seq_mask = pad_sequence(images_seq_mask_list, batch_first = True, padding_value = False)\n\n # Create labels\n labels = input_ids.clone()\n\n # Mask padding tokens\n labels[labels == self.tokenizer.pad_token_id] = -100\n\n # Mask image tokens (model shouldn't predict these)\n labels[images_seq_mask] = -100\n\n # Mask user prompt tokens when train_on_responses_only = True (only train on assistant responses)\n if self.train_on_responses_only:\n for idx, prompt_count in enumerate(prompt_token_counts):\n if prompt_count > 0:\n labels[idx, :prompt_count] = -100\n\n # Create attention mask\n attention_mask = (input_ids != self.tokenizer.pad_token_id).long()\n\n # Prepare images batch (list of tuples)\n images_batch = []\n for item in batch_data:\n images_batch.append((item['images_crop'], item['images_ori']))\n\n # Stack spatial crop info\n images_spatial_crop = torch.cat([item['images_spatial_crop'] for item in batch_data], dim = 0)\n\n return {\n \"input_ids\": input_ids,\n \"attention_mask\": attention_mask,\n \"labels\": labels,\n \"images\": images_batch,\n \"images_seq_mask\": images_seq_mask,\n \"images_spatial_crop\": images_spatial_crop,\n }" |
There was a problem hiding this comment.
Defining a large and complex class like DeepSeekOCR2DataCollator inside a notebook cell harms readability and reusability. This class is also duplicated in other notebooks, which will make maintenance difficult.
Consider moving this class to a separate utils.py file and importing it into the notebooks where it's needed. This will make the notebooks cleaner and easier to follow, and any changes to the data collator will only need to be made in one place.
Overview
This Pull Request introduces a comprehensive adaptation of the official Unsloth AI notebooks. The original work, which was designed primarily for Google Colab environments, has been fully refactored and optimized to run seamlessly on Kaggle's platform.
Key Modifications & Enhancements
Supported Tasks & Models
This updated directory now includes fully functional notebooks covering a wide spectrum of advanced AI development, including:
Motivation
By porting these crucial resources to Kaggle, this project aims to democratize access to high-performance AI fine-tuning. It provides developers and researchers a ready-to-use template to leverage Kaggle's free multi-GPU compute, significantly expanding the practical utility of the Unsloth ecosystem.