UEGAN MPS Implementation

This document focuses on the changes made for Apple Silicon (MPS) support.

UEGAN Codebase Modifications for Apple MPS Support


TensorBoard Logger System Changes

The TensorFlow-based logging system was completely removed and replaced with PyTorch’s native TensorBoard implementation:

from torch.utils.tensorboard import SummaryWriter

class Logger(object):
    """Create a tensorboard logger to log_dir."""
    def __init__(self, log_dir):
        """Initialize summary writer."""
        self.writer = SummaryWriter(log_dir=log_dir)

    def scalar_summary(self, tag, value, step):
        """Add scalar summary."""
        self.writer.add_scalar(tag, value, step)

    def images_summary(self, tag, images, step):
        """Log a batch of images given as an (N, C, H, W) tensor."""
        self.writer.add_images(tag, images, step)

    def histo_summary(self, tag, values, step, bins='tensorflow', walltime=None, max_bins=None):
        """Log a histogram of the tensor of values."""
        self.writer.add_histogram(
            tag, values, global_step=step, bins=bins, walltime=walltime, max_bins=max_bins
        )
        self.writer.flush()  # Explicit flush to ensure data is written

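This eliminates the TensorFlow dependency. torch.utils.tensorboard writes TensorBoard event files itself and only needs the standalone tensorboard package, which keeps the codebase PyTorch-only and simplifies setup on Apple Silicon.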


Memory Management for MPS

On MPS devices, cached memory is now explicitly released with:

torch.mps.empty_cache()
time.sleep(2)  # Added sleep to ensure memory is properly released

The logging methods also clear their temporary dictionaries and image lists after each write, dropping references to tensors so that their memory can be reclaimed.
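The release step can be wrapped in one helper that works for any backend. This is a minimal sketch, assuming a device string such as "mps", "cuda" or "cpu"; the name release_device_memory is hypothetical. It swaps in torch.mps.synchronize(), which waits for queued MPS work, as a deterministic alternative to the fixed pause, and keeps the sleep only as an optional parameter.

```python
import gc
import time

import torch


def release_device_memory(device: str, settle_seconds: float = 0.0) -> None:
    """Hypothetical helper: collect garbage, then return cached allocator memory."""
    gc.collect()  # free tensors kept alive only by reference cycles
    if device == "mps" and torch.backends.mps.is_available():
        torch.mps.synchronize()  # wait for queued MPS kernels to finish
        torch.mps.empty_cache()  # release cached, unused MPS blocks
    elif device == "cuda" and torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
    if settle_seconds > 0:
        time.sleep(settle_seconds)  # optional fixed pause, like the sleep(2) variant
```

On the CPU the helper only runs the garbage collector, so it is safe to call unconditionally at the end of a logging step.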


Device Detection Modernization

Device detection now uses the generic torch.accelerator API (PyTorch 2.6 or later) and falls back to the CPU when no accelerator is available:

self.device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
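Because torch.accelerator only exists in recent releases, a version-tolerant variant can be useful. A sketch, assuming older installs should fall back to the backend-specific checks; detect_device is a hypothetical name.

```python
import torch


def detect_device() -> str:
    """Hypothetical helper: accelerator type ("cuda", "mps", ...) or "cpu"."""
    accel = getattr(torch, "accelerator", None)
    if accel is not None and hasattr(accel, "is_available") and accel.is_available():
        # PyTorch 2.6+: same generic API as the one-liner above
        return accel.current_accelerator().type
    # Fallback for PyTorch releases without torch.accelerator
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
```

The returned string can be passed straight to tensor.to(...) or torch.device(...).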

Orthogonal Initialization for MPS

Because the MPS backend does not implement the QR decomposition that torch.nn.init.orthogonal_ relies on, orthogonal initialization now runs on the CPU and copies the result back:

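elif init_type == 'orthogonal':
    if torch.backends.mps.is_available():
        # QR is not implemented on MPS: initialize a CPU copy of the weight...
        weight = m.weight.to("cpu")
        torch.nn.init.orthogonal_(weight, gain=gain)
        # ...then copy the result back into the original parameter
        m.weight.data.copy_(weight.to(self.device))
    else:
        torch.nn.init.orthogonal_(m.weight, gain=gain)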
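A standalone variant of the same CPU round trip is sketched below. It is an assumption, not the repository's code: it checks the tensor's own device instead of whether MPS is available, so the CPU detour only happens when the weight actually lives on MPS. The name orthogonal_init_ is hypothetical.

```python
import torch
from torch import nn


def orthogonal_init_(weight: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: orthogonal init that also works for MPS tensors."""
    with torch.no_grad():
        if weight.device.type == "mps":
            # Compute the QR-based init on the CPU, then copy it back to MPS
            tmp = torch.empty(weight.shape, dtype=weight.dtype)
            nn.init.orthogonal_(tmp, gain=gain)
            weight.copy_(tmp)
        else:
            nn.init.orthogonal_(weight, gain=gain)
    return weight
```

For a weight with fewer rows than columns, the rows come out orthonormal and scaled by gain, so W Wᵀ equals gain² times the identity.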

VGG Model Modifications

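As part of the port, VGG19 loading was also updated to torchvision's weights API (torchvision 0.13 and later), which supersedes the deprecated pretrained=True argument: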

cnn = models.vgg19(weights='IMAGENET1K_V1').features

InstanceNorm2d Dynamic Allocation

InstanceNorm2d layers are now created explicitly for each VGG19 feature layer, each with that layer's channel count:

self.IN_layers = {
    'relu1_1': nn.InstanceNorm2d(64, affine=False, track_running_stats=False),
    'relu2_1': nn.InstanceNorm2d(128, affine=False, track_running_stats=False),
    'relu3_1': nn.InstanceNorm2d(256, affine=False, track_running_stats=False),
    'relu4_1': nn.InstanceNorm2d(512, affine=False, track_running_stats=False),
    'relu5_1': nn.InstanceNorm2d(512, affine=False, track_running_stats=False),
}

Modern PyTorch Module Usage

Custom modules were replaced with PyTorch built-ins:

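elif act_fun_type == 'Swish':
    return nn.SiLU()  # Swish with beta = 1 is the same function as SiLU: x * sigmoid(x)

elif norm_fun_type == 'none':
    norm_fun = nn.Identity  # a class; nn.Identity accepts and ignores constructor arguments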
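A minimal sketch of the two factory branches, with hypothetical function names (the real code has more branches than shown). It checks that nn.SiLU matches the Swish formula and that the nn.Identity class can stand in for a normalization class that is later called with a channel count.

```python
import torch
from torch import nn


def get_act_fun(act_fun_type: str) -> nn.Module:
    """Hypothetical factory covering only the Swish branch shown above."""
    if act_fun_type == "Swish":
        return nn.SiLU()  # x * sigmoid(x)
    raise ValueError(f"unsupported activation in this sketch: {act_fun_type}")


def get_norm_fun(norm_fun_type: str) -> type:
    """Hypothetical factory returning a layer *class*, instantiated later."""
    if norm_fun_type == "none":
        return nn.Identity  # nn.Identity(64) is valid; the argument is ignored
    raise ValueError(f"unsupported norm in this sketch: {norm_fun_type}")
```

Returning the class rather than an instance lets callers write norm_fun(channels) uniformly, whether the choice is a real normalization layer or no normalization at all.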

Tensor API Modernization

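.data attribute access was replaced with .detach() for modern PyTorch compatibility. Both return a tensor that shares storage with the original and is excluded from autograd, but in-place changes to a detached tensor are still tracked, so autograd raises an error instead of silently computing wrong gradients: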

x.detach()
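The difference can be seen with sigmoid, whose backward pass reuses its own output. The sketch below (function names are illustrative) zeroes that output in place once through .detach() and once through .data.

```python
import torch


def inplace_after_detach_raises() -> bool:
    a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    out = a.sigmoid()  # sigmoid saves its output for the backward pass
    out.detach().zero_()  # shares storage and the version counter with `out`
    try:
        out.sum().backward()
    except RuntimeError:
        return True  # autograd detected the in-place modification
    return False


def inplace_after_data_grad() -> torch.Tensor:
    a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    out = a.sigmoid()
    out.data.zero_()  # bypasses version tracking
    out.sum().backward()  # runs, but uses the zeroed output
    return a.grad  # silently wrong: all zeros
```

With .detach() the mistake surfaces as an error; with .data the gradient is quietly computed from corrupted values.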

Learning Rate Scheduler Updates

The learning rate schedulers are now stepped without arguments; passing an explicit epoch to scheduler.step() is deprecated in PyTorch:

self.lr_scheduler_g.step()
self.lr_scheduler_d.step()
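Since PyTorch 1.1, scheduler.step() is expected after optimizer.step(). The sketch below shows that ordering for separate generator and discriminator optimizers; the tiny linear models, the Adam optimizer and StepLR are stand-ins chosen for illustration, not UEGAN's actual networks or schedule.

```python
import torch
from torch import nn


def run_epochs(num_epochs: int = 4, base_lr: float = 1e-4) -> list:
    """Record the generator LR at the start of each epoch (illustrative setup)."""
    g, d = nn.Linear(2, 1), nn.Linear(2, 1)  # stand-ins for G and D
    opt_g = torch.optim.Adam(g.parameters(), lr=base_lr)
    opt_d = torch.optim.Adam(d.parameters(), lr=base_lr)
    lr_scheduler_g = torch.optim.lr_scheduler.StepLR(opt_g, step_size=2, gamma=0.5)
    lr_scheduler_d = torch.optim.lr_scheduler.StepLR(opt_d, step_size=2, gamma=0.5)
    history = []
    for _ in range(num_epochs):
        history.append(lr_scheduler_g.get_last_lr()[0])
        x = torch.randn(8, 2)
        for model, opt in ((d, opt_d), (g, opt_g)):
            opt.zero_grad()
            model(x).pow(2).mean().backward()
            opt.step()  # optimizer first...
        lr_scheduler_g.step()  # ...then the schedulers, with no epoch argument
        lr_scheduler_d.step()
    return history
```

With step_size=2 and gamma=0.5 the learning rate halves every two epochs.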

Model Saving Improvements

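Model saving now creates the checkpoint directory if it is missing, builds the filename from the run version, adversarial loss type and epoch, and reports save failures through the progress bar instead of crashing: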

from pathlib import Path

save_path = Path(self.model_save_path)
save_path.mkdir(parents=True, exist_ok=True)
model_filename = f"{self.args.version}_{self.args.adv_loss_type}_{current_epoch}.pth"
model_filepath = save_path / model_filename

try:
    torch.save(checkpoint, model_filepath)
    pbar.write(f"✅ Model checkpoint saved: {model_filepath}")
except Exception as e:
    pbar.write(f"❌ Error saving model checkpoint: {e}")
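The same logic can be packaged as a standalone function. This is a sketch under assumptions: save_checkpoint and its log parameter are hypothetical, and, unlike the snippet above, directory creation is also inside the try block so a bad path is reported rather than raised.

```python
from pathlib import Path
from typing import Callable, Optional

import torch


def save_checkpoint(checkpoint: dict, save_dir, version: str, adv_loss_type: str,
                    epoch: int, log: Callable[[str], None] = print) -> Optional[Path]:
    """Save as <version>_<adv_loss_type>_<epoch>.pth; return the path, or None on failure."""
    model_filepath = Path(save_dir) / f"{version}_{adv_loss_type}_{epoch}.pth"
    try:
        model_filepath.parent.mkdir(parents=True, exist_ok=True)
        torch.save(checkpoint, model_filepath)
    except Exception as e:  # broad on purpose, mirroring the original handler
        log(f"❌ Error saving model checkpoint: {e}")
        return None
    log(f"✅ Model checkpoint saved: {model_filepath}")
    return model_filepath
```

Inside the training loop, log=pbar.write keeps the messages from breaking the tqdm bar.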

DataLoader Optimizations

The data loader uses the same device detection and creates its batch iterator when it is initialized:

self.device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
self.iter = iter(self.loader)
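A stored iterator like self.iter is typically used to draw batches step by step and recreated when an epoch runs out. The sketch below assumes that pattern; the class InfiniteBatches and its next_batch method are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class InfiniteBatches:
    """Hypothetical wrapper: yields batches forever, restarting the loader each epoch."""

    def __init__(self, loader: DataLoader, device: str = "cpu"):
        self.loader = loader
        self.device = device
        self.iter = iter(self.loader)

    def next_batch(self):
        try:
            batch = next(self.iter)
        except StopIteration:
            self.iter = iter(self.loader)  # new pass; reshuffles if shuffle=True
            batch = next(self.iter)
        return [t.to(self.device) for t in batch]  # move every tensor to the device
```

Step-based training loops can then call next_batch() without tracking epoch boundaries.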

Batch Dimension Handling

For the NIMA calculation, a single image tensor of shape (C, H, W) now gets an explicit batch dimension, giving (1, C, H, W), before it is passed to the model:

image = image.unsqueeze(0)
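NIMA predicts a distribution over the scores 1 to 10, and its mean score is the sum of each score times its probability. The sketch assumes the NIMA model returns already-normalized probabilities of shape (N, 10); nima_mean_score is a hypothetical name.

```python
import torch
from torch import nn


def nima_mean_score(nima: nn.Module, image: torch.Tensor) -> float:
    """Mean NIMA score for one image; assumes `nima` outputs (N, 10) probabilities."""
    if image.dim() == 3:
        image = image.unsqueeze(0)  # (C, H, W) -> (1, C, H, W)
    with torch.no_grad():
        probs = nima(image)  # probability of each score 1..10
    scores = torch.arange(1, 11, dtype=probs.dtype, device=probs.device)
    return float((probs[0] * scores).sum())
```

A uniform distribution over the ten bins gives a mean score of 5.5.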

Warning Handling Changes

Escalating warnings to errors was disabled (the lines below are commented out), so warnings are printed for debugging instead of stopping training:

# import warnings
# warnings.simplefilter("error")
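If strict behavior is still wanted while debugging one specific part, the filter can be scoped instead of set globally. This is an optional technique, not part of the described change; the helper name is hypothetical.

```python
import warnings


def call_with_warnings_as_errors(fn, *args, **kwargs):
    """Run fn with every warning escalated to an exception, only inside this call."""
    with warnings.catch_warnings():  # restores the previous filters on exit
        warnings.simplefilter("error")
        return fn(*args, **kwargs)
```

Outside the with block, warnings behave as configured globally again.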

Progress Reporting Enhancement

Progress is now reported with tqdm’s write method, which prints each message above the progress bar instead of breaking it:

pbar.write((
    "Elapse:{:>.12s}, D_Step:{:>6d}/{}, G_Step:{:>6d}/{}, "
    "D_loss:{:>.4f}, G_loss:{:>.4f}, G_percep_loss:{:>.4f}, "
    "G_adv_loss:{:>.4f}, G_idt_loss:{:>.4f}"
).format(
    elapsed, step + 1, total_steps, (step + 1), total_steps,
    self.d_loss, self.g_loss, self.g_percep_loss,
    self.g_adv_loss, self.g_idt_loss
))
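The elapsed value is a string truncated to 12 characters by the {:>.12s} field. A sketch, assuming it is produced with datetime.timedelta and that the losses are collected in a dict (both assumptions; format_progress is hypothetical), builds the same message so it can be passed to pbar.write.

```python
import datetime

PROGRESS_FMT = (
    "Elapse:{:>.12s}, D_Step:{:>6d}/{}, G_Step:{:>6d}/{}, "
    "D_loss:{:>.4f}, G_loss:{:>.4f}, G_percep_loss:{:>.4f}, "
    "G_adv_loss:{:>.4f}, G_idt_loss:{:>.4f}"
)


def format_progress(elapsed_seconds: float, step: int, total_steps: int, losses: dict) -> str:
    """Build the progress line; `losses` keys are illustrative assumptions."""
    elapsed = str(datetime.timedelta(seconds=elapsed_seconds))  # e.g. "1:02:05"
    return PROGRESS_FMT.format(
        elapsed, step + 1, total_steps, step + 1, total_steps,
        losses["d_loss"], losses["g_loss"], losses["g_percep_loss"],
        losses["g_adv_loss"], losses["g_idt_loss"],
    )
```

In the loop this would be called as pbar.write(format_progress(time.time() - start_time, step, total_steps, losses)).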

Random Pair Generation Improvement

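For unpaired (unsupervised) training, the two image lists are now shuffled independently, so images are matched at random rather than by their position in the lists: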

random.shuffle(fnames)
random.shuffle(fnames2)
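A reproducible variant is sketched below, assuming pairs are formed by position after shuffling and that the longer list is truncated; make_random_pairs and its seed parameter are hypothetical.

```python
import random


def make_random_pairs(fnames, fnames2, seed=None):
    """Shuffle copies of both lists independently and pair them by position."""
    rng = random.Random(seed)  # private RNG: reproducible without global state
    a, b = list(fnames), list(fnames2)  # copies, so callers' lists stay untouched
    rng.shuffle(a)
    rng.shuffle(b)
    return list(zip(a, b))  # zip stops at the shorter list
```

Passing the same seed yields the same pairing, which helps when comparing runs.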

Together, these changes let the UEGAN codebase run on Apple MPS, remove the TensorFlow dependency, and update the implementation to current PyTorch APIs and practices.

Summarized using Perplexity (Claude 3.7 Sonnet) · Retouched by Duhyeon Kim


Original Source

This implementation is based on: