Build and train the CNN model

Training effective neural networks requires careful architecture design, configurable hyperparameters, and robust training loops. Our CNN implementation uses modern best practices, including batch normalization, dropout regularization, and learning rate scheduling, to achieve reliable digit classification performance.

Configurable model architecture

The DigitCNN class implements a three-layer convolutional neural network designed specifically for MNIST's 28x28 grayscale images. The architecture follows the principle of progressive feature abstraction—early layers detect edges and simple patterns, while deeper layers combine these into complex shapes for final classification:

src/project_ml/defs/assets/model_assets.py
class DigitCNN(nn.Module):
    """Improved CNN for MNIST digit classification based on research."""

    def __init__(self, config: ModelConfig = None):
        super().__init__()
        if config is None:
            config = ModelConfig()

        self.config = config

        # First convolutional block
        self.conv1 = nn.Conv2d(
            1, config.conv1_channels, kernel_size=5, padding=2
        )  # 5x5 kernel, maintain size
        self.bn1 = nn.BatchNorm2d(config.conv1_channels) if config.use_batch_norm else nn.Identity()
        self.pool1 = nn.MaxPool2d(2, 2)  # 28x28 -> 14x14

        # Second convolutional block
        self.conv2 = nn.Conv2d(
            config.conv1_channels, config.conv2_channels, kernel_size=5, padding=2
        )
        self.bn2 = nn.BatchNorm2d(config.conv2_channels) if config.use_batch_norm else nn.Identity()
        self.pool2 = nn.MaxPool2d(2, 2)  # 14x14 -> 7x7

        # Third convolutional block (new)
        self.conv3 = nn.Conv2d(
            config.conv2_channels, config.conv3_channels, kernel_size=3, padding=1
        )
        self.bn3 = nn.BatchNorm2d(config.conv3_channels) if config.use_batch_norm else nn.Identity()
        self.pool3 = nn.AdaptiveAvgPool2d((3, 3))  # Adaptive pooling to 3x3

        # Dropout layers
        self.dropout1 = nn.Dropout2d(config.dropout1_rate)
        self.dropout2 = nn.Dropout(config.dropout2_rate)

        # Calculate the flattened size: 3x3 * conv3_channels
        conv_output_size = 3 * 3 * config.conv3_channels

        # Fully connected layers
        self.fc1 = nn.Linear(conv_output_size, config.hidden_size)
        self.fc2 = nn.Linear(config.hidden_size, config.hidden_size // 2)  # Additional FC layer
        self.fc3 = nn.Linear(config.hidden_size // 2, 10)

    def _conv_block(self, x, conv, bn, pool, dropout=None):
        """Apply a convolutional block: conv -> bn -> relu -> pool -> dropout (optional)."""
        x = conv(x)
        x = bn(x)
        x = F.relu(x)
        x = pool(x)
        if dropout is not None:
            x = dropout(x)
        return x

    def _fc_block(self, x, fc, dropout=None):
        """Apply a fully connected block: linear -> relu -> dropout (optional)."""
        x = fc(x)
        x = F.relu(x)
        if dropout is not None:
            x = dropout(x)
        return x

    def forward(self, x):
        """Forward pass through the CNN architecture.

        Input: (batch_size, 1, 28, 28) - MNIST digit images
        Output: (batch_size, 10) - Raw logits for 10 digit classes

        Architecture flow:
        1. Conv1: 28x28 -> 14x14 (5x5 kernel, 32 channels)
        2. Conv2: 14x14 -> 7x7 (5x5 kernel, 64 channels) + spatial dropout
        3. Conv3: 7x7 -> 3x3 (3x3 kernel, 128 channels, adaptive pooling)
        4. Flatten: 3x3*128 = 1152 features
        5. FC layers: 1152 -> 256 -> 128 -> 10 (with dropout)
        """
        # Convolutional layers with progressive downsampling
        x = self._conv_block(x, self.conv1, self.bn1, self.pool1)
        x = self._conv_block(x, self.conv2, self.bn2, self.pool2, self.dropout1)
        x = self._conv_block(x, self.conv3, self.bn3, self.pool3)

        # Flatten spatial dimensions for fully connected layers
        x = torch.flatten(x, 1)  # Keep batch dimension

        # Fully connected layers with progressive feature reduction
        x = self._fc_block(x, self.fc1, self.dropout2)
        x = self._fc_block(x, self.fc2)
        x = self.fc3(x)  # Final layer - no activation (raw logits)

        return x  # Return raw logits for CrossEntropyLoss

The architecture demonstrates key design principles: progressive downsampling reduces spatial dimensions while increasing feature depth (28×28 → 14×14 → 7×7 → 3×3), batch normalization after each convolution stabilizes training and enables higher learning rates, and strategic dropout prevents overfitting on spatial patterns. The configurable design allows easy experimentation with different channel sizes, dropout rates, and architectural components through the ModelConfig system.
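
To make the shape bookkeeping concrete, the short, hypothetical snippet below instantiates DigitCNN with a customized ModelConfig and pushes a fake batch through it. The channel sizes chosen here are arbitrary, and the import path is assumed to mirror the file path shown above under a src layout:

import torch

from project_ml.defs.assets.model_assets import DigitCNN, ModelConfig

# Smaller, arbitrary channel sizes for a quick experiment
config = ModelConfig(conv1_channels=16, conv2_channels=32, conv3_channels=64, hidden_size=128)

model = DigitCNN(config)
model.eval()  # disable dropout and batch-norm updates for a pure shape check

with torch.no_grad():
    fake_batch = torch.randn(4, 1, 28, 28)  # four fake grayscale "digits"
    logits = model(fake_batch)

print(logits.shape)  # expected: torch.Size([4, 10]) - one raw logit per class

Because the third block uses adaptive pooling to 3x3, the flattened size adjusts automatically (3 * 3 * 64 = 576 here), so the fully connected layers stay consistent no matter which channel counts you configure.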

Training configuration system

Rather than hardcoding training parameters, the system uses Dagster's configuration framework to enable experimentation without code modifications. The ModelConfig class centralizes all training hyperparameters, from model architecture to optimization strategies:

src/project_ml/defs/assets/model_assets.py
class ModelConfig(dg.Config):
    """Configuration for model architecture and training."""

    # Architecture parameters
    conv1_channels: int = 32  # Reduced complexity
    conv2_channels: int = 64
    conv3_channels: int = 128
    dropout1_rate: float = 0.1  # Reduced dropout
    dropout2_rate: float = 0.2
    hidden_size: int = 256
    use_batch_norm: bool = True

    # Training parameters
    batch_size: int = DEFAULT_BATCH_SIZE  # Smaller batch size for better generalization
    learning_rate: float = DEFAULT_LEARNING_RATE  # Reduced learning rate
    epochs: int = DEFAULT_EPOCHS  # Increased epochs
    optimizer_type: str = "adam"  # Changed to Adam
    momentum: float = 0.9
    weight_decay: float = 1e-5  # Reduced weight decay

    # Learning rate scheduling
    use_lr_scheduler: bool = True
    lr_step_size: int = LR_STEP_SIZE
    lr_gamma: float = LR_GAMMA

    # Early stopping
    use_early_stopping: bool = True
    patience: int = EARLY_STOPPING_PATIENCE
    min_delta: float = MIN_DELTA

    # Data augmentation - Research proven techniques
    use_data_augmentation: bool = True
    rotation_degrees: float = 15.0  # Increased from 10
    translation_pixels: float = 0.1  # New parameter
    scale_range_min: float = 0.9  # Split tuple into two floats
    scale_range_max: float = 1.1  # Split tuple into two floats

    # Model saving
    save_model: bool = True
    model_save_dir: str = str(MODELS_DIR)
    model_name_prefix: str = "mnist_cnn"

This configuration approach separates model architecture from training strategy, enabling data scientists to experiment with different hyperparameters through configuration files while keeping the underlying training logic stable. The configuration includes advanced features like learning rate scheduling (StepLR), early stopping with patience, multiple optimizer support, and automatic model persistence with descriptive filenames.
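
As one illustration of that separation, the hedged sketch below builds a run config that overrides a few hyperparameters for a single run. The specific values are arbitrary, and the same overrides could be supplied as YAML through the Dagster UI launchpad instead of code:

import dagster as dg

from project_ml.defs.assets.model_assets import ModelConfig

# Override only the knobs being experimented with; every other field keeps
# its ModelConfig default.
experiment_config = ModelConfig(
    learning_rate=5e-4,
    epochs=20,
    conv3_channels=96,
    use_data_augmentation=False,
)

# Keyed by asset name, this run config can be passed to dg.materialize(...)
# when launching the training pipeline programmatically.
run_config = dg.RunConfig({"digit_classifier": experiment_config})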

Training asset orchestration

The digit_classifier asset coordinates the entire training process, from data loading through model persistence. This asset demonstrates how Dagster assets can orchestrate complex ML workflows while providing comprehensive logging and metadata generation:

src/project_ml/defs/assets/model_assets.py
@dg.asset(
    description="Train CNN digit classifier with configurable parameters",
    group_name="model_pipeline",
    required_resource_keys={"model_storage"},
)
def digit_classifier(
    context,
    processed_mnist_data: dict[str, torch.Tensor],
    config: ModelConfig,
) -> DigitCNN:
    """Train a CNN to classify handwritten digits 0-9 with flexible configuration."""
    context.log.info(f"Training with config: {config.model_dump()}")

    train_data = processed_mnist_data["train_data"]
    val_data = processed_mnist_data["val_data"]
    train_labels = processed_mnist_data["train_labels"]
    val_labels = processed_mnist_data["val_labels"]

    # Create data loaders with configurable batch size
    train_dataset = TensorDataset(train_data, train_labels)
    val_dataset = TensorDataset(val_data, val_labels)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)

    # Initialize model with configuration
    model = DigitCNN(config)

    # Train the model - pass context to train_model
    trained_model, train_losses, val_accuracies = train_model(
        context, model, train_loader, val_loader, config
    )

    final_val_accuracy = val_accuracies[-1]

    # Add metadata
    context.add_output_metadata(
        {
            "final_val_accuracy": final_val_accuracy,
            "training_epochs": len(train_losses),
            "configured_epochs": config.epochs,
            "model_parameters": sum(p.numel() for p in trained_model.parameters()),
            "final_train_loss": train_losses[-1],
            "learning_rate": config.learning_rate,
            "batch_size": config.batch_size,
            "optimizer": config.optimizer_type,
            "early_stopping_used": config.use_early_stopping,
        }
    )

    context.log.info(
        f"Model training completed. Final validation accuracy: {final_val_accuracy:.2f}%"
    )

    # Save model as pickle file if requested
    if config.save_model:
        # Create models directory if it doesn't exist
        model_dir = Path(config.model_save_dir)
        model_dir.mkdir(exist_ok=True)

        # Create filename with timestamp and accuracy
        from datetime import datetime

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        accuracy_str = f"{final_val_accuracy:.2f}".replace(".", "p")
        filename = f"{config.model_name_prefix}_{timestamp}_acc{accuracy_str}.pkl"

        # Save the trained model
        context.log.info(f"Saving model as {filename}")
        model_store = context.resources.model_storage
        model_store.save_model(trained_model, filename)

        context.add_output_metadata(
            {"model_name": filename, "final_accuracy": final_val_accuracy},
            output_name="result",
        )

    return trained_model

The training asset integrates seamlessly with upstream data processing through Dagster's dependency system, ensuring training only begins after data preprocessing completes. It accepts configuration parameters that control all aspects of training behavior, enabling different strategies across development and production environments. The asset generates rich metadata including training metrics, model statistics, and configuration parameters that appear in Dagster's UI for experiment tracking and comparison.

Advanced training features and monitoring

The training system includes several features aimed at production ML workflows: early stopping monitors validation accuracy and halts training when improvement stalls for a configurable number of epochs (patience), StepLR scheduling decays the learning rate on a fixed epoch schedule for more stable convergence, and comprehensive logging tracks both epoch-level progress and batch-level details for debugging. The simplified sketch below shows how these pieces typically fit together.
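
This sketch is not the project's actual train_model implementation; it uses Adam for brevity and reads the relevant knobs from ModelConfig:

import torch
import torch.nn as nn


def train_sketch(model, train_loader, val_loader, config, device="cpu"):
    """Illustrative loop combining StepLR scheduling with patience-based early stopping."""
    optimizer = torch.optim.Adam(
        model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay
    )
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=config.lr_step_size, gamma=config.lr_gamma
    )
    criterion = nn.CrossEntropyLoss()  # expects the raw logits the model returns

    best_accuracy, epochs_without_improvement = 0.0, 0
    train_losses, val_accuracies = [], []

    for epoch in range(config.epochs):
        model.train()
        epoch_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        train_losses.append(epoch_loss / len(train_loader))
        scheduler.step()  # decay the learning rate every lr_step_size epochs

        # Validation accuracy drives early stopping
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                preds = model(images.to(device)).argmax(dim=1)
                correct += (preds == labels.to(device)).sum().item()
                total += labels.size(0)
        accuracy = 100.0 * correct / total
        val_accuracies.append(accuracy)

        # Stop when accuracy has not improved by min_delta for `patience` epochs
        if accuracy > best_accuracy + config.min_delta:
            best_accuracy, epochs_without_improvement = accuracy, 0
        else:
            epochs_without_improvement += 1
            if config.use_early_stopping and epochs_without_improvement >= config.patience:
                break

    return model, train_losses, val_accuracies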

Multiple optimizer support (Adam for fast convergence, SGD with momentum for potentially better final performance) provides flexibility across training scenarios. The system also selects GPU or CPU automatically and includes robust error handling for production deployment.
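
A minimal sketch of that selection logic, assuming a helper named build_optimizer (not the project's API), might look like this:

import torch


def build_optimizer(model, config):
    """Return the optimizer named by config.optimizer_type (assumed helper)."""
    if config.optimizer_type.lower() == "sgd":
        return torch.optim.SGD(
            model.parameters(),
            lr=config.learning_rate,
            momentum=config.momentum,
            weight_decay=config.weight_decay,
        )
    # Fall back to Adam for any other value, matching the config default
    return torch.optim.Adam(
        model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay
    )


# Prefer a GPU when one is available, otherwise run on CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")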

Model persistence uses descriptive filenames including timestamps and performance metrics, enabling easy model identification and version management. The integration with Dagster's resource system abstracts storage details, supporting both local development and cloud production environments seamlessly.
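
The model_storage resource itself is not shown in this section. As a rough illustration of what it abstracts, a local implementation could be as small as the hypothetical sketch below; the class name and base_dir field are assumptions, though the save_model(model, filename) call matches how the asset uses the resource:

import pickle
from pathlib import Path

import dagster as dg


class LocalModelStorage(dg.ConfigurableResource):
    """Hypothetical local resource that persists trained models as pickle files."""

    base_dir: str = "models"

    def save_model(self, model, filename: str) -> str:
        target_dir = Path(self.base_dir)
        target_dir.mkdir(parents=True, exist_ok=True)
        target = target_dir / filename
        with open(target, "wb") as f:
            pickle.dump(model, f)  # filename already carries timestamp and accuracy
        return str(target)

A cloud-backed implementation could expose the same save_model method while writing to object storage, which is what lets the training asset stay unchanged across environments.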

Next steps

With trained models available through our asset pipeline, the next phase focuses on comprehensive evaluation to assess model performance and determine readiness for production deployment.