Key Takeaways
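- torch.compile(): one line that often gives a sizeable throughput gain on supported GPUs (figures of 20 to 50% are commonly quoted, but the speedup depends on the model). Benchmark it on your own workload before relying on it in production training.
- Device-agnostic code: write once, run on CUDA, MPS (Apple Silicon), or CPU.
- optimizer.zero_grad() first, always: gradients accumulate across backward() calls, so forgetting this is one of the most common training bugs.
- Checkpoint often: save every N epochs. Resuming from a checkpoint means a power cut doesn't restart training from epoch 1.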
Introduction
Direct Answer: How do I train a neural network with PyTorch 2.x on a local GPU in 2026?
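- Install: pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124 for an NVIDIA GPU (CUDA 12.4), or pip install torch torchvision on macOS (MPS/CPU). On Linux the default PyPI wheel already bundles CUDA; use --index-url https://download.pytorch.org/whl/cpu for a CPU-only build.
- Detect the device: device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu").
- Define the model as a torch.nn.Module subclass and move it to the device: model = model.to(device).
- Optimise (optional): compiled = torch.compile(model). Keep the original model object for checkpointing, because the compiled wrapper's state_dict() prefixes every key with "_orig_mod.".
- Training loop: optimizer.zero_grad(); outputs = compiled(inputs.to(device)); loss = criterion(outputs, labels.to(device)); loss.backward(); optimizer.step().
- Checkpoint: torch.save({"model": model.state_dict()}, "checkpoint.pt"). Load: model.load_state_dict(torch.load("checkpoint.pt", map_location=device)["model"]).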
Part 1: Setup and Device Detection
```bash
# Install PyTorch with CUDA 12.4 (NVIDIA GPU)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
# Install PyTorch on macOS (Apple Silicon MPS or CPU)
pip install torch torchvision torchaudio
# Install a CPU-only build on Linux (the default PyPI wheel includes CUDA)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
# Verify
python3 -c "
import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')
print(f'MPS available: {torch.backends.mps.is_available()}')
"
```
Expected output on an RTX 4090 (version strings and memory figures depend on your GPU, driver, and PyTorch build):
PyTorch: 2.5.1+cu124
CUDA available: True
GPU: NVIDIA GeForce RTX 4090
VRAM: 24.6 GB
MPS available: False
```python
# device_setup.py: device-agnostic template
import torch


def get_device() -> torch.device:
    """Return the best available device."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Using Apple Silicon GPU (MPS)")
    else:
        device = torch.device("cpu")
        print("Using CPU (no GPU detected)")
    return device


device = get_device()
```
Part 2: Define a Neural Network
```python
# model.py: image classification CNN
import torch
import torch.nn as nn


class ConvNet(nn.Module):
    """Simple CNN for image classification (e.g., CIFAR-10)."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # Feature extraction layers
        self.features = nn.Sequential(
            # Block 1: 3x32x32 → 32x16x16
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            # Block 2: 32x16x16 → 64x8x8
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            # Block 3: 64x8x8 → 128x4x4
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )
        # Classification layers
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(128 * 4 * 4, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = torch.flatten(x, 1)  # (batch, 128, 4, 4) → (batch, 2048)
        return self.classifier(x)


if __name__ == "__main__":
    # Inspect model (runs only when model.py is executed directly, not when imported)
    model = ConvNet(num_classes=10)
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
```
Expected output:
Total parameters: 1,147,914
Trainable parameters: 1,147,914
Part 3: Training Loop
```python
# train.py: complete training pipeline
from contextlib import nullcontext
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from device_setup import device  # Part 1
from model import ConvNet  # Part 2

USE_AMP = device.type == "cuda"  # float16 autocast + GradScaler are used on NVIDIA GPUs only
EPOCHS = 20

# ── Data loading ──────────────────────────────────────────────────────────
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train)
val_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_val)
pin = device.type == "cuda"  # pinned host memory speeds up CPU-to-GPU copies
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4, pin_memory=pin)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=4, pin_memory=pin)

# ── Model, optimiser, loss ────────────────────────────────────────────────
model = ConvNet(num_classes=10).to(device)
# Compile on CUDA, where gains are usually largest; support on other backends varies by release.
# Keep `model` itself for checkpoints: the compiled wrapper prefixes state_dict keys with "_orig_mod.".
compiled_model = torch.compile(model) if device.type == "cuda" else model
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
scaler = torch.amp.GradScaler("cuda", enabled=USE_AMP)  # Automatic mixed precision (no-op when disabled)


def autocast_ctx():
    """float16 autocast on CUDA; a do-nothing context manager elsewhere."""
    return torch.autocast(device_type="cuda", dtype=torch.float16) if USE_AMP else nullcontext()


# ── Checkpoint helpers ────────────────────────────────────────────────────
CKPT_DIR = Path("checkpoints")
CKPT_DIR.mkdir(exist_ok=True)


def save_checkpoint(epoch: int, val_acc: float, best_val_acc: float, filename: str):
    torch.save({
        "epoch": epoch,
        "model": model.state_dict(),  # uncompiled module: plain ConvNet keys
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "scaler": scaler.state_dict(),
        "val_acc": val_acc,
        "best_val_acc": best_val_acc,
    }, CKPT_DIR / filename)


def load_checkpoint(path) -> tuple[int, float]:
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    if ckpt["scaler"]:  # empty when the run that saved it had AMP disabled
        scaler.load_state_dict(ckpt["scaler"])
    print(f"Resumed from epoch {ckpt['epoch']} (val_acc={ckpt['val_acc']:.4f})")
    return ckpt["epoch"], ckpt["best_val_acc"]


# ── Training loop ─────────────────────────────────────────────────────────
def train_epoch() -> tuple[float, float]:
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for inputs, labels in train_loader:
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()  # Clear old gradients before backward(); they accumulate otherwise
        with autocast_ctx():  # Mixed precision on NVIDIA GPUs
            outputs = compiled_model(inputs)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # Unscale first so clipping sees the true gradient norm
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item() * inputs.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += inputs.size(0)
    return total_loss / total, correct / total


def validate() -> tuple[float, float]:
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad(), autocast_ctx():
        for inputs, labels in val_loader:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = compiled_model(inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * inputs.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += inputs.size(0)
    return total_loss / total, correct / total


# ── Main training run ─────────────────────────────────────────────────────
if __name__ == "__main__":  # required when DataLoader workers are spawned (macOS, Windows)
    start_epoch, best_val_acc = 1, 0.0
    last_ckpt = CKPT_DIR / "last.pt"
    if last_ckpt.exists():  # Resume after an interruption
        last_epoch, best_val_acc = load_checkpoint(last_ckpt)
        start_epoch = last_epoch + 1
    for epoch in range(start_epoch, EPOCHS + 1):
        train_loss, train_acc = train_epoch()
        val_loss, val_acc = validate()
        scheduler.step()
        print(f"Epoch {epoch:3d}/{EPOCHS} | "
              f"Train: loss={train_loss:.4f} acc={train_acc:.4f} | "
              f"Val: loss={val_loss:.4f} acc={val_acc:.4f}")
        if val_acc > best_val_acc:  # Save best model
            best_val_acc = val_acc
            save_checkpoint(epoch, val_acc, best_val_acc, "best.pt")
            print("  ✓ New best! Saved checkpoint.")
        save_checkpoint(epoch, val_acc, best_val_acc, "last.pt")  # Every epoch, for resuming
```
Expected output (abridged; exact values vary between runs and hardware):
Epoch 1/20 | Train: loss=1.8234 acc=0.3421 | Val: loss=1.6847 acc=0.3892
Epoch 5/20 | Train: loss=1.1247 acc=0.6134 | Val: loss=1.0983 acc=0.6287
Epoch 10/20 | Train: loss=0.8471 acc=0.7034 | Val: loss=0.9124 acc=0.6943 ← New best! Saved.
Epoch 20/20 | Train: loss=0.6234 acc=0.7891 | Val: loss=0.8341 acc=0.7432 ← New best! Saved.
Part 4: GPU Monitoring
# Monitor GPU during training
watch -n 1 nvidia-smi --query-gpu=utilization.gpu,memory.used,memory.total,temperature.gpu \
--format=csv,noheader,nounits
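Example output during training (columns: GPU utilization %, memory used in MiB, memory total in MiB, temperature in °C; nounits drops the unit labels). The figures are illustrative; the small ConvNet above typically needs only a fraction of this memory:
94, 21400, 24564, 68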
Part 5: Inference
```python
# inference.py: load the trained model and classify an image
import torch
from PIL import Image
from torchvision import transforms

from device_setup import device  # Part 1
from model import ConvNet  # Part 2

CLASSES = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
CKPT_PATH = "checkpoints/best.pt"  # best checkpoint saved during training

# Load checkpoint
model = ConvNet(num_classes=10).to(device)
ckpt = torch.load(CKPT_PATH, map_location=device)
# Drop the "_orig_mod." prefix that appears if weights were saved from a torch.compile() wrapper
state_dict = {k.removeprefix("_orig_mod."): v for k, v in ckpt["model"].items()}
model.load_state_dict(state_dict)
model.eval()  # disable dropout, use BatchNorm running statistics

transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # must match training
])


def classify(image_path: str) -> tuple[str, float]:
    img = Image.open(image_path).convert("RGB")
    tensor = transform(img).unsqueeze(0).to(device)  # Add batch dimension
    with torch.no_grad():
        output = model(tensor)
        probs = torch.softmax(output, dim=1)
        confidence, predicted = probs.max(1)
    return CLASSES[predicted.item()], confidence.item()


if __name__ == "__main__":
    label, conf = classify("/tmp/test.jpg")
    print(f"Predicted: {label} ({conf * 100:.1f}% confidence)")
```
Conclusion
PyTorch 2.x with torch.compile(), automatic mixed precision, and systematic checkpointing gives you efficient deep learning that runs entirely on your own GPU hardware. The device-agnostic pattern lets the same code run on NVIDIA (CUDA), Apple Silicon (MPS), or CPU, although backend-specific features such as torch.compile() and float16 AMP deliver their largest gains on CUDA.
See Computer Vision with YOLOv11 2026 for a complete pre-trained object detection pipeline, and Machine Learning with scikit-learn 2026 for the simpler non-deep-learning ML baseline.
People Also Ask
When should I use PyTorch vs scikit-learn?
scikit-learn for structured/tabular data: CSV files, database queries, feature engineering, and traditional ML algorithms (Random Forest, SVM, Gradient Boosting). PyTorch for unstructured data: images, text sequences, audio, video, or any task where learning the representation (feature extraction) is part of the problem. If your data is a spreadsheet and your problem is classification or regression, start with scikit-learn; if your data is images, text, or audio, use PyTorch. Deep learning typically needs far more labelled data and compute (usually a GPU) than traditional ML, so don't reach for it when scikit-learn already solves the problem.
How much GPU VRAM do I need for deep learning?
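These are rough guidelines; actual usage depends on batch size, input or sequence length, precision, and memory-saving techniques such as gradient checkpointing. For training small models such as the ConvNet above, 4GB+ VRAM is sufficient. For fine-tuning a pre-trained ResNet-50 or ViT, 8GB+ is recommended. For training Transformer models on text: 16GB+ (BERT-base), 24GB+ (GPT-2 medium). For fine-tuning LLMs with QLoRA: 12GB+ (see QLoRA Fine-tuning Guide). An RTX 3060 12GB handles most vision tasks; an RTX 4090 24GB handles most single-GPU training and fine-tuning, but not full LLM pre-training or full fine-tuning of large models.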
Part 12: Production Training and Deployment
A practical PyTorch deployment is not just about model code; it is about reproducible training, efficient inference, and observability.
12.1 Environment isolation
Use isolated Python environments for training and deployment. Pin dependencies with poetry.lock, requirements.txt, or conda environment files.
12.2 Reproducibility
Fix random seeds for PyTorch, NumPy, and Python.
```python
import random

import numpy as np
import torch

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
```
Record the exact CUDA, cuDNN, PyTorch, and driver versions used for training.
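Seeding alone does not make GPU runs bit-for-bit repeatable, because some cuDNN and CUDA kernels are nondeterministic, and the versions still need to be written down somewhere. A minimal sketch of both, assuming a JSON file next to your checkpoints is an acceptable record (the function names and env_info.json are illustrative; PyTorch does not expose the driver version, so take that from nvidia-smi):

```python
import json
import platform

import torch


def enable_determinism() -> None:
    """Prefer deterministic kernels; this can slow training down."""
    torch.backends.cudnn.benchmark = False  # autotuning may pick different kernels per run
    torch.backends.cudnn.deterministic = True
    # warn_only=True warns instead of raising for ops without a deterministic implementation.
    # On CUDA, deterministic cuBLAS also needs CUBLAS_WORKSPACE_CONFIG=:4096:8 set before launch.
    torch.use_deterministic_algorithms(True, warn_only=True)


def environment_info() -> dict:
    """Collect the versions needed to reproduce a training run."""
    return {
        "python": platform.python_version(),
        "torch": torch.__version__,
        "cuda": torch.version.cuda,  # None on CPU-only and macOS builds
        "cudnn": torch.backends.cudnn.version() if torch.cuda.is_available() else None,
        "gpu": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
    }


if __name__ == "__main__":
    enable_determinism()
    with open("env_info.json", "w") as f:
        json.dump(environment_info(), f, indent=2)
```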
12.3 Mixed precision training
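Use automatic mixed precision via torch.amp (the older torch.cuda.amp entry points are deprecated) to reduce memory consumption and improve performance.
```python
scaler = torch.amp.GradScaler("cuda")
optimizer.zero_grad()
with torch.autocast(device_type="cuda", dtype=torch.float16):
    output = model(inputs)
    loss = criterion(output, targets)
scaler.scale(loss).backward()  # scaled loss avoids float16 gradient underflow
scaler.step(optimizer)  # unscales gradients; skips the step if they contain inf/NaN
scaler.update()  # adjusts the scale factor for the next iteration
```
Mixed precision is standard practice for production GPU training. On GPUs that support bfloat16 (NVIDIA Ampere and newer), dtype=torch.bfloat16 usually works without a GradScaler, because bfloat16 has the same exponent range as float32.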
12.4 Distributed training
Use torch.distributed or higher-level launchers for multi-GPU and multi-node training.
python -m torch.distributed.run --nproc_per_node=4 train.py
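Design your training script so that the effective global batch size (per-GPU batch size × number of GPUs × gradient-accumulation steps) stays what you intend when the number of GPUs changes, and so that it can accumulate gradients over several micro-batches.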
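Gradient accumulation emulates a larger batch by summing gradients over several micro-batches before a single optimizer step. The sketch below assumes a mean-reduced loss, so each micro-batch loss is divided by the number of accumulation steps; it then checks that the accumulated gradient matches a single large-batch gradient (the toy linear model and sizes are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 1)
criterion = nn.MSELoss()  # mean reduction
inputs, targets = torch.randn(32, 8), torch.randn(32, 1)
ACCUM_STEPS = 4  # 4 micro-batches of 8 = effective batch of 32

# Accumulate gradients over equal-sized micro-batches
model.zero_grad()
for x, y in zip(inputs.chunk(ACCUM_STEPS), targets.chunk(ACCUM_STEPS)):
    loss = criterion(model(x), y) / ACCUM_STEPS  # scale so the sum equals the full-batch mean
    loss.backward()  # gradients add up in .grad
accum_grad = model.weight.grad.clone()
# In a training loop, optimizer.step() and optimizer.zero_grad() would follow here

# Reference: one full batch
model.zero_grad()
criterion(model(inputs), targets).backward()
full_grad = model.weight.grad.clone()

print(torch.allclose(accum_grad, full_grad, atol=1e-6))
```

Under DistributedDataParallel, wrapping the non-final micro-batches in model.no_sync() avoids redundant gradient all-reduces between GPUs.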
12.5 Checkpointing and model versioning
Checkpoint often and keep a stable best model copy.
torch.save({
'epoch': epoch,
'model_state': model.state_dict(),
'optimizer_state': optimizer.state_dict(),
}, 'checkpoint.pt')
Store model metadata with a hash of the training dataset, seed, and config.
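One way to do this, sketched under the assumption that the dataset is a directory of files: hash every file in a stable order and write the digest, seed, and config to a JSON file next to the checkpoint. The helper names and the metadata layout are illustrative, not a standard format.

```python
from __future__ import annotations

import hashlib
import json
from pathlib import Path


def hash_directory(root: str | Path) -> str:
    """SHA-256 over relative paths and file contents, visited in sorted order."""
    root = Path(root)
    digest = hashlib.sha256()
    for path in sorted(p for p in root.rglob("*") if p.is_file()):
        digest.update(path.relative_to(root).as_posix().encode())
        digest.update(b"\0")  # separator between the path and the file contents
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):  # 1 MiB chunks keep memory flat
                digest.update(chunk)
    return digest.hexdigest()


def write_metadata(ckpt_path: str | Path, data_dir: str | Path, seed: int, config: dict) -> Path:
    """Write <checkpoint>.json with the dataset hash, seed, and training config."""
    meta = {
        "checkpoint": Path(ckpt_path).name,
        "dataset_sha256": hash_directory(data_dir),
        "seed": seed,
        "config": config,
    }
    meta_path = Path(ckpt_path).with_suffix(".json")
    meta_path.write_text(json.dumps(meta, indent=2, sort_keys=True))
    return meta_path
```

For example, write_metadata("checkpoint.pt", "./data", seed=42, config={"lr": 1e-3, "batch_size": 256}) would write checkpoint.json beside the checkpoint.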
12.6 Profiling and optimization
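Profile GPU utilization and operator bottlenecks with the PyTorch profiler. When a schedule is set, call prof.step() after every training step and run at least wait + warmup + active steps; otherwise the schedule never reaches its recording phase.
```python
import torch

with torch.profiler.profile(
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./log"),
    record_shapes=True,
    profile_memory=True,
) as prof:
    for _ in range(5):  # wait (1) + warmup (1) + active (3) steps
        train_step()  # your own function: one forward/backward/optimizer step
        prof.step()  # advance the profiler schedule
```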
Use profiler data to tune batch size, data pipeline, and model architecture.
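For a quick look without TensorBoard, the same profiler can print an operator summary. A small CPU-only sketch (the model and input sizes are placeholders; add ProfilerActivity.CUDA to the activities list when profiling on a GPU):

```python
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile, record_function

model = nn.Sequential(nn.Linear(256, 512), nn.ReLU(), nn.Linear(512, 10))
inputs = torch.randn(64, 256)

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("forward_backward"):  # labels this region in the report
        model(inputs).sum().backward()

# Top operators by time spent in the op itself (excluding the ops it calls)
report = prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10)
print(report)
```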
Part 13: Efficient Inference
Production inference has different priorities from training: latency, throughput, and cost.
13.1 TorchScript and ONNX
Export models for optimized inference.
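```python
model.eval()  # trace in inference mode: dropout off, BatchNorm uses running statistics
example_input = torch.randn(1, 3, 32, 32, device=next(model.parameters()).device)  # assumed input shape
traced = torch.jit.trace(model, example_input)
traced.save("model.pt")
```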
Or export to ONNX with torch.onnx.export for portability across runtimes such as ONNX Runtime.
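Whichever format you export to, it is worth checking that the exported artifact reproduces the eager model's outputs before shipping it. A minimal sketch for the TorchScript path (the small Sequential model and the 1×3×32×32 input are placeholders; newer PyTorch releases steer new projects towards torch.export, so treat TorchScript as one option rather than the default):

```python
import io

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
).eval()  # eval mode: fixed BatchNorm statistics, so outputs are deterministic

example_input = torch.randn(1, 3, 32, 32)
traced = torch.jit.trace(model, example_input)

buffer = io.BytesIO()  # in-memory stand-in for traced.save("model.pt")
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

test_batch = torch.randn(4, 3, 32, 32)  # a different batch size from the trace input
with torch.no_grad():
    max_diff = (model(test_batch) - loaded(test_batch)).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```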
13.2 Quantization
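Use dynamic or static quantization to reduce model size and improve latency. For static post-training quantization on x86 CPUs, the default fbgemm config is a starting point, followed by prepare, calibration on sample data, and convert (get_default_qat_qconfig is the counterpart for quantization-aware training):
```python
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
```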
Quantized models are especially useful for CPU inference.
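Dynamic quantization is the lowest-effort option: Linear weights are stored as int8 and activations are quantized on the fly, with no calibration data needed. A sketch using the torch.ao.quantization API (the toy model is a placeholder; this API targets CPU inference, and recent releases are moving quantization work into the separate torchao project, so check the documentation for your version):

```python
import io

import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic

model = nn.Sequential(nn.Linear(512, 1024), nn.ReLU(), nn.Linear(1024, 10)).eval()

# Return a copy in which every nn.Linear is replaced by a dynamically quantized int8 version
quantized = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)


def serialized_size(m: nn.Module) -> int:
    """Bytes needed to save the module's state_dict."""
    buffer = io.BytesIO()
    torch.save(m.state_dict(), buffer)
    return buffer.getbuffer().nbytes


x = torch.randn(8, 512)
with torch.no_grad():
    max_diff = (model(x) - quantized(x)).abs().max().item()

print(f"fp32: {serialized_size(model)} bytes, int8: {serialized_size(quantized)} bytes")
print(f"max abs difference: {max_diff:.4f}")
```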
13.3 Batch and request sizing
Choose batch sizes that maximize GPU utilization while meeting latency targets. For low-latency applications, smaller batches may be better.
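A simple way to choose is to measure latency and throughput for a few candidate batch sizes on the target hardware. A rough CPU sketch (the model, input shape, and candidate sizes are placeholders; on a GPU, call torch.cuda.synchronize() before reading the clock, because CUDA kernels run asynchronously):

```python
import time

import torch
import torch.nn as nn


def benchmark(model: nn.Module, batch_sizes, input_shape, repeats: int = 10) -> dict:
    """Return {batch_size: (median latency in ms, samples per second)}."""
    model.eval()
    results = {}
    with torch.inference_mode():
        for bs in batch_sizes:
            x = torch.randn(bs, *input_shape)
            model(x)  # warm-up run, excluded from timing
            timings = []
            for _ in range(repeats):
                start = time.perf_counter()
                model(x)
                timings.append(time.perf_counter() - start)
            median = sorted(timings)[len(timings) // 2]
            results[bs] = (median * 1000, bs / median)
    return results


if __name__ == "__main__":
    net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 256), nn.ReLU(), nn.Linear(256, 10))
    for bs, (ms, throughput) in benchmark(net, [1, 8, 32, 128], (3, 32, 32)).items():
        print(f"batch={bs:4d}  latency={ms:7.2f} ms  throughput={throughput:9.0f} samples/s")
```

Pick the largest batch whose latency still fits your target; throughput usually rises with batch size until the hardware saturates.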
13.4 Serving with TorchServe
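TorchServe provides a production-grade serving layer; check the project's current maintenance status before adopting it. Package the model with torch-model-archiver, then start the server with the resulting archive:
```bash
# handler.py is your own custom handler (hypothetical file name)
torch-model-archiver --model-name mymodel --version 1.0 --serialized-file model.pt --handler handler.py --export-path model_store
torchserve --start --model-store model_store --models mymodel=mymodel.mar
```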
Use custom handlers for preprocessing and postprocessing to keep the model isolated from business logic.
13.5 Local API design
Wrap inference in an HTTP API with versioned endpoints.
Example request contract:
{ "model": "resnet50", "input": [...], "options": {"top_k": 5} }
This ensures clients do not rely on internal tensor formats.
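The sketch below shows the idea: the handler accepts the JSON contract above, converts the input to a tensor internally, and returns plain JSON-serializable values, so the tensor layout never leaks to clients. The field names follow the example contract; the class list, input size, model registry, and tiny model are hypothetical placeholders, and a real service would sit behind an HTTP framework of your choice.

```python
import json

import torch
import torch.nn as nn

CLASSES = ["cat", "dog", "bird"]  # hypothetical label set
INPUT_SIZE = 4  # hypothetical feature count
# Hypothetical registry: versioned model names map to loaded models
MODELS = {"tiny-v1": nn.Linear(INPUT_SIZE, len(CLASSES)).eval()}


def handle_predict(payload: dict) -> dict:
    """Validate a request, run the model, and return only JSON-serializable values."""
    name = payload.get("model")
    if name not in MODELS:
        return {"error": f"unknown model: {name!r}"}
    try:
        x = torch.tensor(payload["input"], dtype=torch.float32).reshape(1, -1)
        top_k = int(payload.get("options", {}).get("top_k", 1))
    except (KeyError, TypeError, ValueError, AttributeError) as exc:
        return {"error": f"invalid request: {exc}"}
    if x.shape[1] != INPUT_SIZE:
        return {"error": f"expected {INPUT_SIZE} values, got {x.shape[1]}"}
    top_k = max(1, min(top_k, len(CLASSES)))  # clamp to a valid range
    with torch.no_grad():
        probs = torch.softmax(MODELS[name](x), dim=1)[0]
    scores, indices = probs.topk(top_k)
    return {
        "model": name,
        "predictions": [
            {"label": CLASSES[i], "score": round(s, 4)}
            for s, i in zip(scores.tolist(), indices.tolist())
        ],
    }


if __name__ == "__main__":
    request = {"model": "tiny-v1", "input": [0.1, 0.2, 0.3, 0.4], "options": {"top_k": 2}}
    print(json.dumps(handle_predict(request), indent=2))
```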
Part 14: Data Pipeline and Dataset Management
Your model is only as good as the data you feed it.
14.1 Data validation
Validate dataset schema and contents before training. Use checks for missing values, label distribution, and corrupted examples.
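A minimal pre-training check, assuming a map-style dataset that returns (tensor, integer label) pairs, such as the torchvision datasets used earlier. The function name and the imbalance threshold are illustrative:

```python
from collections import Counter

import torch
from torch.utils.data import Dataset


def validate_dataset(dataset: Dataset, num_classes: int, expected_shape: tuple,
                     max_imbalance: float = 10.0) -> list[str]:
    """Return human-readable problems; an empty list means every check passed."""
    problems = []
    label_counts = Counter()
    for idx in range(len(dataset)):
        x, y = dataset[idx]
        y = int(y)
        if tuple(x.shape) != expected_shape:
            problems.append(f"sample {idx}: shape {tuple(x.shape)} != {expected_shape}")
        if not torch.isfinite(x).all():
            problems.append(f"sample {idx}: contains NaN or Inf")
        if not 0 <= y < num_classes:
            problems.append(f"sample {idx}: label {y} out of range")
        label_counts[y] += 1
    missing = [c for c in range(num_classes) if label_counts[c] == 0]
    if missing:
        problems.append(f"classes with no samples: {missing}")
    present = [label_counts[c] for c in range(num_classes) if label_counts[c] > 0]
    if present and max(present) / min(present) > max_imbalance:
        problems.append(f"label imbalance ratio {max(present) / min(present):.1f} exceeds {max_imbalance}")
    return problems
```

With the CIFAR-10 data above, this would be called as validate_dataset(train_dataset, num_classes=10, expected_shape=(3, 32, 32)).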
14.2 Dataset versioning
Version datasets with tools like DVC or Git LFS. Keep a record of dataset hashes and splits for every experiment.
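For the splits, one option is to derive them from a seeded generator and store the resulting indices, so any run can rebuild exactly the same train/validation partition. A sketch with torch.utils.data.random_split (the 80/20 ratio, seed, and JSON layout are assumptions):

```python
import json

import torch
from torch.utils.data import Dataset, random_split


def split_and_record(dataset: Dataset, seed: int, path: str, val_fraction: float = 0.2):
    """Split deterministically and save the indices so the split can be audited later."""
    generator = torch.Generator().manual_seed(seed)
    train_set, val_set = random_split(dataset, [1 - val_fraction, val_fraction], generator=generator)
    with open(path, "w") as f:
        json.dump({"seed": seed, "train": list(train_set.indices), "val": list(val_set.indices)}, f)
    return train_set, val_set
```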
14.3 Augmentation strategies
Use data augmentation conservatively and test that each transformation preserves the label. Over-augmentation can introduce noise and degrade generalization.
Part 15: Debugging and Model Quality
A mature PyTorch workflow includes robust debugging and evaluation.
15.1 Gradient inspection
Check for exploding or vanishing gradients. Log gradient norms, parameter norms, and learning rates.
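A small helper for this, assuming it is called after loss.backward() (and after scaler.unscale_() when using AMP, so the norms are not inflated by the loss scale). The function name is illustrative; pass the returned dict to whatever logger you use:

```python
import torch
import torch.nn as nn


def gradient_norms(model: nn.Module) -> dict[str, float]:
    """Per-parameter L2 gradient norms plus the global norm used by clip_grad_norm_."""
    norms = {
        name: p.grad.detach().norm(2).item()
        for name, p in model.named_parameters()
        if p.grad is not None
    }
    norms["total"] = sum(v ** 2 for v in norms.values()) ** 0.5
    return norms
```

A large or steadily growing total suggests exploding gradients; norms close to zero in early layers suggest vanishing gradients.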
15.2 Sanity checks
Run small overfitting tests on a tiny dataset to ensure the model can learn. If it cannot, there is a bug in the architecture or training loop.
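A sketch of this test: take one small fixed batch and train on it until the loss is close to zero. The MLP, random data, step count, and threshold are arbitrary choices; for your own model, swap in a handful of real samples (heavy dropout may need to be disabled for the check to pass).

```python
import torch
import torch.nn as nn


def overfit_check(model: nn.Module, inputs: torch.Tensor, labels: torch.Tensor,
                  steps: int = 500, lr: float = 1e-2) -> float:
    """Train repeatedly on one fixed batch and return the last recorded loss."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.train()
    for _ in range(steps):
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
    return loss.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    net = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 4))
    x, y = torch.randn(16, 20), torch.randint(0, 4, (16,))
    final_loss = overfit_check(net, x, y)
    print(f"final loss: {final_loss:.4f}")
    assert final_loss < 0.05, "cannot overfit a tiny batch: check the architecture and training loop"
```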
15.3 Evaluation metrics
Use appropriate metrics for your task. For classification, track precision, recall, F1, and confusion matrices. For regression, use MAE and RMSE.
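The definitions are short enough to compute directly from predictions. A plain-PyTorch sketch (libraries such as scikit-learn or torchmetrics provide tested equivalents; the function names here are illustrative):

```python
import torch


def confusion_matrix(preds: torch.Tensor, targets: torch.Tensor, num_classes: int) -> torch.Tensor:
    """cm[i, j] = number of samples whose true class is i and predicted class is j."""
    idx = targets * num_classes + preds
    return torch.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def precision_recall_f1(cm: torch.Tensor):
    """Per-class precision, recall, and F1; classes with no support get 0."""
    tp = cm.diag().float()
    precision = tp / cm.sum(dim=0).clamp(min=1)  # column sums: how often each class was predicted
    recall = tp / cm.sum(dim=1).clamp(min=1)  # row sums: how often each class actually occurs
    f1 = 2 * precision * recall / (precision + recall).clamp(min=1e-12)
    return precision, recall, f1


def mae_rmse(preds: torch.Tensor, targets: torch.Tensor) -> tuple[float, float]:
    """Mean absolute error and root mean squared error for regression."""
    err = preds.float() - targets.float()
    return err.abs().mean().item(), err.pow(2).mean().sqrt().item()
```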
Part 16: Security and Governance
Machine learning infrastructure must also follow security best practices.
16.1 Access controls
Restrict who can run training jobs and who can deploy models. Use role-based access controls at the compute and data levels.
16.2 Data privacy
Protect personal or regulated data in training datasets. Use anonymization, minimization, and local processing when possible.
16.3 Model auditing
Keep an audit trail for model training runs, hyperparameters, and deployment approvals. This is essential for reproducibility and compliance.
Part 17: Monitoring and Observability
A production ML system needs monitoring like any other service.
17.1 Training metrics
Track training and validation loss, accuracy, and system metrics such as GPU utilization. Store them in a time series database, experiment tracking tool, or ML dashboard.
17.2 Drift detection
Monitor model performance on production data. If accuracy decays or input distributions drift, trigger retraining or investigation.
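One common way to quantify input drift is the Population Stability Index (PSI) between a reference sample (for example, training data) and recent production data for a single feature. A sketch, assuming a one-dimensional numeric feature; the 10 quantile bins are a conventional choice, not a requirement:

```python
import torch


def population_stability_index(reference: torch.Tensor, current: torch.Tensor,
                               bins: int = 10, eps: float = 1e-6) -> float:
    """PSI = sum((cur% - ref%) * ln(cur% / ref%)) over bins set by reference quantiles."""
    reference, current = reference.flatten().float(), current.flatten().float()
    edges = torch.quantile(reference, torch.linspace(0, 1, bins + 1))
    # Bin by the interior edges; values outside the reference range land in the end bins
    ref_idx = torch.bucketize(reference, edges[1:-1])
    cur_idx = torch.bucketize(current, edges[1:-1])
    ref_pct = torch.bincount(ref_idx, minlength=bins).float() / reference.numel()
    cur_pct = torch.bincount(cur_idx, minlength=bins).float() / current.numel()
    ref_pct, cur_pct = ref_pct.clamp(min=eps), cur_pct.clamp(min=eps)  # avoid log(0)
    return ((cur_pct - ref_pct) * torch.log(cur_pct / ref_pct)).sum().item()
```

A frequently used rule of thumb reads values below about 0.1 as stable and above 0.2 as a significant shift worth investigating; calibrate the threshold on your own data.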
17.3 Infrastructure alerts
Alert on GPU memory pressure, disk exhaustion, and failed data loads. A blocked training pipeline is as important as a crashed service.
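A minimal polling check that could feed an alerting system, assuming thresholds of 90% GPU memory and 90% disk usage (both illustrative). torch.cuda.mem_get_info reports free and total device memory as seen by the driver, so it also counts memory used by other processes on the GPU:

```python
import shutil

import torch


def resource_alerts(path: str = ".", gpu_threshold: float = 0.9, disk_threshold: float = 0.9) -> list[str]:
    """Return alert messages for GPU memory and disk usage above the given fractions."""
    alerts = []
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            free, total = torch.cuda.mem_get_info(i)
            used = 1 - free / total
            if used > gpu_threshold:
                alerts.append(f"GPU {i}: memory {used:.0%} used")
    disk = shutil.disk_usage(path)
    disk_used = disk.used / disk.total
    if disk_used > disk_threshold:
        alerts.append(f"disk at {path}: {disk_used:.0%} used")
    return alerts


if __name__ == "__main__":
    print(resource_alerts() or "all resources within thresholds")
```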
Part 18: Transfer Learning and Fine-Tuning Patterns
Transfer learning remains one of the most efficient ways to build models in 2026.
18.1 Frozen backbone and head tuning
Freeze base model layers and train only the last layers when data is limited.
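In PyTorch, freezing means turning off requires_grad for the backbone and handing only the trainable parameters to the optimizer. A sketch with a stand-in backbone (in practice this would be a pretrained network, for example from torchvision; the layer sizes are placeholders):

```python
import torch
import torch.nn as nn

backbone = nn.Sequential(  # stand-in for a pretrained feature extractor
    nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
)
head = nn.Linear(16, 5)  # new task-specific classifier
model = nn.Sequential(backbone, head)

for p in backbone.parameters():
    p.requires_grad = False  # no gradients are computed for frozen weights
backbone.eval()  # also freeze BatchNorm running statistics; repeat after any model.train() call

optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-3)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {trainable}")  # only the head: 16 * 5 + 5 = 85
```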
18.2 Layer-wise learning rates
Use lower learning rates for pretrained layers and higher rates for new layers. This preserves learned representations while adapting to your task.
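Optimizer parameter groups express this directly, since each group carries its own learning rate. A sketch, assuming the model is split into a pretrained backbone and a new head (the 10× ratio between the two rates is a common starting point, not a rule):

```python
import torch
import torch.nn as nn

model = nn.ModuleDict({
    "backbone": nn.Sequential(nn.Linear(32, 64), nn.ReLU()),  # stand-in for pretrained layers
    "head": nn.Linear(64, 3),  # new layers for the target task
})

optimizer = torch.optim.AdamW(
    [
        {"params": model["backbone"].parameters(), "lr": 1e-4},  # small steps preserve pretrained features
        {"params": model["head"].parameters(), "lr": 1e-3},  # larger steps for freshly initialised layers
    ],
    weight_decay=1e-2,
)

for group in optimizer.param_groups:
    print(group["lr"], sum(p.numel() for p in group["params"]))
```

Learning-rate schedulers scale each group's rate independently, so the ratio between the groups is preserved over training.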
18.3 Validation strategy
Use cross-validation or held-out test sets. Do not tune hyperparameters on the final test set.
Part 19: Scaling and Cost Management
Training and serving deep learning models can be expensive.
19.1 Spot and preemptible instances
For non-critical training jobs, use spot/preemptible GPU instances to reduce cost. Implement checkpointing so jobs can resume after interruption.
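Preemption can arrive in the middle of a write, so it helps to save checkpoints atomically: write to a temporary file, then rename it over the old one. A sketch under that assumption (file and helper names are illustrative; os.replace is atomic when source and destination are on the same filesystem):

```python
from __future__ import annotations

import os
from pathlib import Path

import torch
import torch.nn as nn


def save_atomic(state: dict, path: str | Path) -> None:
    """Write to a temp file in the same directory, then atomically swap it in."""
    path = Path(path)
    tmp = path.with_name(path.name + ".tmp")
    torch.save(state, tmp)
    os.replace(tmp, path)  # readers see either the old or the new file, never a partial one


def resume_or_start(path: str | Path, model: nn.Module, optimizer: torch.optim.Optimizer) -> int:
    """Load the checkpoint if it exists and return the next epoch to run."""
    path = Path(path)
    if not path.exists():
        return 0
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])  # copies into the model's existing tensors and devices
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["epoch"] + 1
```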
19.2 Mixed workloads
Co-locate training and preprocessing carefully. Avoid competing workloads on the same GPU when possible.
19.3 Resource quotas
Set quota policies to prevent runaway jobs from consuming all available GPUs or storage.
Part 20: Final Operational Checklist
- training and inference environments are isolated and versioned
- data pipelines are validated and version controlled
- production inference is served with a stable API contract
- metrics for training and inference are captured and alerted on
- model checkpoints are stored with metadata and hashes
- security controls protect training data and serving endpoints
- experiment results are logged and reviewed
- reproducibility is enforced with fixed seeds and dependency tracking
A production-ready PyTorch workflow is not just code; it is a set of engineered practices that keep models reliable, efficient, and manageable in a self-hosted environment.
Part 21: Model Explainability and Responsible AI
Explainability is increasingly important for production machine learning.
21.1 Feature importance and model introspection
Use tools such as SHAP or Captum to explain predictions. Capture explanations alongside model predictions for debugging.
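Captum packages many attribution methods; the simplest of them, a vanilla gradient saliency map, needs only autograd. The dependency-free sketch below illustrates that idea and is not Captum's API (the toy model and input shape are placeholders):

```python
import torch
import torch.nn as nn


def saliency(model: nn.Module, x: torch.Tensor, target: int) -> torch.Tensor:
    """Absolute gradient of the target class logit with respect to each input element."""
    model.eval()
    x = x.detach().clone().requires_grad_(True)
    logits = model(x)
    # Summing over the batch is safe: each sample's logit depends only on its own input
    logits[:, target].sum().backward()
    return x.grad.abs()
```

Because this also writes into the model parameters' .grad fields, call model.zero_grad() afterwards if the same model is being trained.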
21.2 Bias detection
Test models for bias across relevant subgroups. Measure performance gaps and track whether model changes reduce or increase bias.
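A minimal way to measure such gaps, assuming each evaluation sample carries a subgroup identifier (how subgroups are defined is task-specific and a policy decision in itself; the function name is illustrative):

```python
import torch


def subgroup_accuracy(preds: torch.Tensor, targets: torch.Tensor,
                      groups: torch.Tensor) -> tuple[dict[int, float], float]:
    """Accuracy per subgroup and the largest gap between any two subgroups."""
    per_group = {}
    for g in groups.unique().tolist():
        mask = groups == g
        per_group[g] = (preds[mask] == targets[mask]).float().mean().item()
    gap = max(per_group.values()) - min(per_group.values())
    return per_group, gap
```

Tracking the gap value across model versions shows whether a change narrows or widens the disparity.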
21.3 Responsible deployment
Only deploy models with appropriate guardrails. For high-risk tasks, include a human review step or a safety filter.
Part 22: Experimentation and Search
Hyperparameter search helps you find the best model configuration.
22.1 Grid and random search
Use grid search for small parameter spaces and random search for larger spaces. Log every trial with its metrics and config.
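Random search is easy to write by hand. A sketch, assuming you supply an objective that trains a model for a config and returns a validation score (replaced here by a toy function so the example runs); the learning rate is sampled log-uniformly because its useful range spans orders of magnitude:

```python
import json
import math
import random


def sample_config(rng: random.Random) -> dict:
    return {
        "lr": 10 ** rng.uniform(-5, -1),  # log-uniform between 1e-5 and 1e-1
        "batch_size": rng.choice([64, 128, 256]),
        "weight_decay": 10 ** rng.uniform(-6, -2),
    }


def random_search(objective, n_trials: int, seed: int = 0, log_path: str = "trials.jsonl") -> dict:
    """Evaluate n_trials random configs, log every trial, and return the best one."""
    rng = random.Random(seed)
    best = None
    with open(log_path, "w") as log:
        for trial in range(n_trials):
            config = sample_config(rng)
            score = objective(config)
            record = {"trial": trial, "config": config, "score": score}
            log.write(json.dumps(record) + "\n")  # one JSON line per trial
            if best is None or score > best["score"]:
                best = record
    return best


def toy_objective(config: dict) -> float:
    """Stand-in for a real training run: best score at lr = 1e-3."""
    return -abs(math.log10(config["lr"]) + 3)
```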
22.2 Bayesian optimization
For more efficient searches, use Bayesian optimization frameworks such as Optuna or Ray Tune.
22.3 Early stopping
Use early stopping on validation metrics to avoid wasting compute on models that already stopped improving.
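A small early-stopping helper, assuming a validation metric where higher is better (such as accuracy); the patience and min_delta defaults are illustrative:

```python
class EarlyStopping:
    """Signal a stop when the metric has not improved by min_delta for `patience` epochs."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, metric: float) -> bool:
        """Record one epoch's metric; return True when training should stop."""
        if metric > self.best + self.min_delta:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

In a training loop: stopper = EarlyStopping(patience=5), then break out of the epoch loop when stopper.step(val_acc) returns True.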
Part 23: Data Pipeline Optimization
The data pipeline is often the slowest part of training.
23.1 Prefetching and caching
Use DataLoader with prefetch_factor and num_workers to keep GPUs fed.
dataloader = DataLoader(dataset, batch_size=32, num_workers=8, pin_memory=True, prefetch_factor=2)
23.2 Mixed storage and data locality
Store frequently used data on fast local SSDs. Use network storage sparingly for datasets that are accessed repeatedly.
23.3 Data sanity checks at ingest
Validate batch shape, type, and value ranges during data ingestion. Fail early if corruption is detected.
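One way to fail early is to wrap the DataLoader in a generator that checks each batch before it reaches the model. A sketch, assuming (images, labels) batches with pixel values normalised to [-1, 1], as produced by the Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) transform used earlier; the helper name and bounds are assumptions:

```python
import torch


def checked_batches(loader, expected_shape: tuple, num_classes: int, value_range=(-1.0, 1.0)):
    """Yield batches unchanged; raise ValueError on the first malformed one.

    Run this on CPU batches, before .to(device), to avoid extra GPU synchronisation.
    """
    lo, hi = value_range
    for i, (x, y) in enumerate(loader):
        if tuple(x.shape[1:]) != expected_shape:
            raise ValueError(f"batch {i}: sample shape {tuple(x.shape[1:])}, expected {expected_shape}")
        if not x.is_floating_point() or not torch.isfinite(x).all():
            raise ValueError(f"batch {i}: inputs must be finite floating-point values")
        if x.min() < lo or x.max() > hi:
            raise ValueError(f"batch {i}: values outside [{lo}, {hi}]")
        if y.dtype != torch.int64 or y.min() < 0 or y.max() >= num_classes:
            raise ValueError(f"batch {i}: labels must be int64 in [0, {num_classes})")
        yield x, y
```

For example: for inputs, labels in checked_batches(train_loader, (3, 32, 32), 10): ...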
Part 24: Deployment Patterns for Edge and Cloud
PyTorch models can be deployed in a wide range of environments.
24.1 Local edge inference
For edge devices, optimize models for CPU or small GPUs. Use TorchScript or ONNX for portable runtimes.
24.2 Containerized inference services
Package inference code and model artifacts into a container. Keep the image small and the runtime pinned.
24.3 Hybrid deployments
Use a local inference edge for latency-sensitive requests and a centralized server for heavy batch workloads. This balances performance and resource utilization.
Part 25: Security and Dependency Hygiene
Keep your training and inference stack secure.
25.1 Dependency pinning
Pin all Python dependencies and update them in a controlled process. Use tools such as pip-audit to identify vulnerable packages.
25.2 Secrets management
Store credentials for data storage, experiment tracking, and monitoring outside of code. Use environment variables or a secrets manager.
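A minimal pattern for reading a credential from the environment and failing fast when it is missing (the variable name in the comment is a hypothetical example):

```python
import os


def require_env(name: str) -> str:
    """Return an environment variable's value or raise with a clear message."""
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(f"missing required environment variable {name}; set it outside the code")
    return value


# Hypothetical variable name; a secrets manager can populate it at deploy time.
# tracking_token = require_env("EXPERIMENT_TRACKER_TOKEN")
```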
25.3 Runtime isolation
Run model serving in a sandboxed environment with minimal privileges. Do not expose debug or developer ports in production.
Part 26: Scaling the ML Platform
A growing model platform needs a scalable foundation.
26.1 Job orchestration
Use workflow orchestration tools like Airflow, Prefect, or Dagster to manage data prep and training jobs.
26.2 Resource quotas and scheduling
Define quotas for GPU and CPU resources. Schedule large jobs during off-peak hours when possible.
26.3 Model lifecycle management
Track model versions from experiment to staging to production. Use a registry or metadata store to manage the lifecycle.
Part 27: Model Compression and Optimization
Optimizing models for production is essential for cost and latency.
27.1 Pruning and sparsity
Prune weights that contribute little to model accuracy. Use structured pruning for more efficient inference on modern hardware.
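torch.nn.utils.prune applies pruning as a mask over a module's weights. A sketch of structured pruning that zeroes half of a conv layer's output channels by L2 norm (the layer and the 50% amount are placeholders). Masked channels still occupy memory and compute, so real speedups require physically removing the zeroed channels or a runtime that exploits the sparsity:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(16, 32, kernel_size=3)

# Zero the 50% of output channels (dim=0) with the smallest L2 norm (n=2)
prune.ln_structured(conv, name="weight", amount=0.5, n=2, dim=0)
prune.remove(conv, "weight")  # bake the mask into the weight and drop the pruning hooks

channel_norms = conv.weight.detach().flatten(1).norm(dim=1)
zeroed = int((channel_norms == 0).sum())
print(f"zeroed output channels: {zeroed} / {conv.out_channels}")  # 16 / 32
```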
27.2 Knowledge distillation
Use distillation to train smaller models that mimic larger ones. This is especially valuable for edge and CPU deployments.
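A widely used formulation (from Hinton et al., 2015) mixes the usual cross-entropy on true labels with a KL-divergence term between temperature-softened teacher and student distributions. A sketch; the temperature and mixing weight are typical starting values, not prescriptions:

```python
import torch
import torch.nn.functional as F


def distillation_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor, labels: torch.Tensor,
                      temperature: float = 4.0, alpha: float = 0.5) -> torch.Tensor:
    """alpha * soft-target KL term + (1 - alpha) * hard-label cross-entropy."""
    soft_student = F.log_softmax(student_logits / temperature, dim=1)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=1)
    # Multiplying by T^2 keeps the soft-term gradient scale comparable across temperatures
    kd = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * temperature ** 2
    ce = F.cross_entropy(student_logits, labels)
    return alpha * kd + (1 - alpha) * ce
```

During training, compute teacher_logits with the teacher in eval mode under torch.no_grad(), so only the student receives gradients.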
27.3 Operator fusion and custom kernels
Leverage fused operators and custom CUDA or CPU kernels for critical paths. This can significantly reduce latency for convolution-heavy or transformer workloads.
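A simple fusion that needs no custom kernels is folding a BatchNorm layer into the preceding convolution for inference, so one conv replaces conv + BN (torch.compile also fuses many elementwise operations automatically). A sketch using torch.nn.utils.fusion.fuse_conv_bn_eval, which requires both modules to be in eval mode; the layer sizes are placeholders:

```python
import torch
import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_eval

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
bn = nn.BatchNorm2d(8)
# Give BN non-trivial running statistics, as a trained model would have
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)
conv.eval()
bn.eval()

fused = fuse_conv_bn_eval(conv, bn)  # a single Conv2d with BN folded into its weight and bias

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    max_diff = (bn(conv(x)) - fused(x)).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```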
Part 28: Serving Patterns for Real-Time Inference
Real-time inference demands reliability and low latency.
28.1 Asynchronous request handling
Use asynchronous serving frameworks or thread pools to keep the model server responsive under concurrent load. Avoid blocking the main event loop on long CPU-bound requests.
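In an asyncio-based server, one way to keep the event loop responsive is to run the blocking forward pass in a worker thread with asyncio.to_thread (Python 3.9+). A sketch; the model is a placeholder, and the lock that serialises model calls is a conservative assumption you may be able to relax for your model:

```python
import asyncio
import threading

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).eval()
model_lock = threading.Lock()


def predict_sync(x: torch.Tensor) -> list[int]:
    """Blocking inference; runs in a worker thread."""
    with model_lock, torch.inference_mode():
        return model(x).argmax(dim=1).tolist()


async def predict(x: torch.Tensor) -> list[int]:
    # The event loop keeps serving other coroutines while the worker thread computes
    return await asyncio.to_thread(predict_sync, x)


async def main() -> None:
    requests = [torch.randn(2, 16) for _ in range(5)]
    results = await asyncio.gather(*(predict(x) for x in requests))
    print(results)


if __name__ == "__main__":
    asyncio.run(main())
```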
28.2 Warmup requests
Warm up the model on startup with representative inputs. This avoids the first-request latency spike that occurs when caches are cold.
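A warm-up routine can be as simple as running a few dummy batches through the model at startup, before the server accepts traffic. A sketch (the input shape, batch sizes, and iteration count are assumptions; with torch.compile, the first calls also trigger compilation, which makes warm-up even more important):

```python
import time

import torch
import torch.nn as nn


def warm_up(model: nn.Module, input_shape: tuple, batch_sizes=(1, 8), iterations: int = 3) -> list[float]:
    """Run dummy batches through the model and return each call's latency in ms."""
    model.eval()
    device = next(model.parameters()).device
    latencies = []
    with torch.inference_mode():
        for bs in batch_sizes:  # cover the batch sizes you expect in production
            x = torch.zeros(bs, *input_shape, device=device)
            for _ in range(iterations):
                start = time.perf_counter()
                model(x)
                if device.type == "cuda":
                    torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
                latencies.append((time.perf_counter() - start) * 1000)
    return latencies
```

Logging the returned latencies shows when the first-call spike has settled.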
28.3 Graceful degradation
Implement fallback logic for overloaded inference servers. Return a cached result, a smaller model response, or a user-visible retry prompt rather than failing outright.
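A sketch of a timeout with fallback using asyncio.wait_for: if the primary model does not answer within the latency budget, a smaller fallback model answers instead and the response is flagged as degraded. Both models and the 50 ms budget are placeholders; note that the timed-out thread keeps running in the background, so this bounds user-facing latency rather than server load:

```python
import asyncio

import torch
import torch.nn as nn

primary = nn.Sequential(nn.Linear(16, 256), nn.ReLU(), nn.Linear(256, 4)).eval()
fallback = nn.Linear(16, 4).eval()  # smaller, faster model


def run(model: nn.Module, x: torch.Tensor) -> list[int]:
    with torch.inference_mode():
        return model(x).argmax(dim=1).tolist()


async def predict_with_fallback(x: torch.Tensor, timeout_s: float = 0.05) -> dict:
    try:
        preds = await asyncio.wait_for(asyncio.to_thread(run, primary, x), timeout=timeout_s)
        return {"predictions": preds, "degraded": False}
    except asyncio.TimeoutError:
        return {"predictions": run(fallback, x), "degraded": True}
```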
Part 29: Experiment Tracking and Governance
Track experiments to turn raw training runs into reliable products.
29.1 Metadata capture
Capture metadata for each experiment: dataset version, hyperparameters, training hardware, random seed, and code commit hash.
29.2 Experiment dashboards
Use an experiment tracking tool to compare runs. Track validation curves, loss trajectories, and early stopping events.
29.3 Approval gates
Require a review before promoting a model from staging to production. Evaluate not just accuracy but behavior on edge cases and safety metrics.
Part 30: Distributed Team Collaboration
PyTorch work often spans research, engineering, and operations.
30.1 Shared reproducible examples
Keep small, reproducible examples in the repository for common workflows: training, evaluation, exporting, and serving.
30.2 Code review standards
Review model code for tensor shapes, device handling, numerical stability, and data poisoning risks. Treat model changes with the same rigor as backend services.
30.3 Knowledge transfer
Document not just how the model works, but why certain architecture and training choices were made. This helps future maintainers understand the tradeoffs.
Part 31: Final Production Validation
Before promoting a model to production, validate it on real-world data against clear acceptance criteria.
31.1 Shadow testing
Run the new model in shadow mode against real requests while the old model remains active. Compare outputs and performance without affecting users.
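At its core, shadow mode means both models see the same inputs while only the current model's answer is served; the comparison is logged for offline review. A sketch of the comparison step (the models and the agreement metric are placeholders; in production, shadow traffic is usually mirrored asynchronously so it cannot slow down responses):

```python
import torch
import torch.nn as nn


def shadow_compare(current: nn.Module, candidate: nn.Module, batches) -> dict:
    """Serve `current`, score `candidate` on the same inputs, and report agreement."""
    current.eval()
    candidate.eval()
    agree, total, served = 0, 0, []
    with torch.inference_mode():
        for x in batches:
            live = current(x).argmax(dim=1)
            shadow = candidate(x).argmax(dim=1)  # never returned to users
            served.append(live)
            agree += (live == shadow).sum().item()
            total += x.size(0)
    return {"agreement": agree / total, "served": torch.cat(served)}
```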
31.2 Performance regression checks
Ensure the new model does not regress latency, throughput, or memory usage. Validate that the hardware it will run on can support it under peak load.
31.3 Deployment checklist
- model artifact is versioned and checksummed
- serving environment is configured with pinned runtime versions
- rollback path is tested
- monitoring is in place for both quality and infrastructure metrics
- data drift alerts are configured
Further Reading
- Computer Vision with YOLOv11 2026 — pre-trained object detection pipeline
- Machine Learning with scikit-learn 2026 — simpler ML baseline
- Fine-Tune Llama 4 with QLoRA 2026 — LLM fine-tuning with PyTorch
- On-Device AI Inference 2026 — hardware selection for training
Tested on: Ubuntu 24.04 LTS (RTX 4090, CUDA 12.4), macOS Sequoia 15.4 (M3 Max). PyTorch 2.5.1. Last verified: May 1, 2026.