UEGAN MPS Implementation
This document focuses on the changes made for Apple Silicon (MPS) support.
UEGAN Codebase Modifications for Apple MPS Support
TensorBoard Logger System Changes
The TensorFlow-based logging system was completely removed and replaced with PyTorch’s native TensorBoard implementation:
from torch.utils.tensorboard import SummaryWriter

class Logger(object):
    """Create a tensorboard logger to log_dir."""

    def __init__(self, log_dir):
        """Initialize summary writer."""
        self.writer = SummaryWriter(log_dir=log_dir)

    def scalar_summary(self, tag, value, step):
        """Add scalar summary."""
        self.writer.add_scalar(tag, value, step)

    def images_summary(self, tag, images, step):
        """Log a list of images."""
        self.writer.add_images(tag, images, step)

    def histo_summary(self, tag, values, step, bins='tensorflow', walltime=None, max_bins=None):
        """Log a histogram of the tensor of values."""
        self.writer.add_histogram(
            tag, values, global_step=step, bins=bins, walltime=walltime, max_bins=max_bins
        )
        self.writer.flush()  # Explicit flush to ensure data is written
This eliminates all TensorFlow dependencies, which keeps the codebase consistent with PyTorch and simplifies Apple Silicon compatibility.
Memory Management for MPS
On MPS devices, memory is now managed with:
torch.mps.empty_cache()
time.sleep(2) # Added sleep to ensure memory is properly released
Within logging methods, dictionaries and image lists are also cleared to optimize memory usage.
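The cleanup step above can be wrapped in a small guard so that it is safe to call on machines without MPS. This is a minimal sketch, not the project's code: the helper name release_mps_memory, the pause_seconds parameter, and the extra gc.collect() call are assumptions added for illustration.

```python
import gc
import time

import torch


def release_mps_memory(pause_seconds: float = 2.0) -> bool:
    """Free cached MPS memory when an MPS device is available.

    Returns True if the MPS cache was emptied, False otherwise.
    """
    # Assumption: collect unreachable Python objects first so that the
    # tensors they hold can actually be returned to the allocator.
    gc.collect()
    if not torch.backends.mps.is_available():
        return False
    torch.mps.empty_cache()
    # Mirrors the time.sleep(2) used above; the right duration is not verified.
    time.sleep(pause_seconds)
    return True


emptied = release_mps_memory(pause_seconds=0.0)
```

On a machine without MPS the helper simply returns False, so the same training loop can run unchanged on CPU or CUDA.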
Device Detection Modernization
Device detection logic is modernized as follows:
self.device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
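The torch.accelerator namespace is only present in recent PyTorch releases (2.6 and later, to the best of our knowledge). The sketch below shows one possible fallback for older versions; the function name pick_device and the fallback order (MPS, then CUDA, then CPU) are assumptions, not part of the modified codebase.

```python
import torch


def pick_device() -> str:
    """Return a device type string such as "mps", "cuda", or "cpu"."""
    # Preferred path: the generic accelerator API, when this build has it.
    accelerator = getattr(torch, "accelerator", None)
    if accelerator is not None and accelerator.is_available():
        return accelerator.current_accelerator().type
    # Fallbacks for PyTorch versions without torch.accelerator.
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


device = pick_device()
probe = torch.zeros(1, device=device)  # confirms the string is usable
```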
Orthogonal Initialization for MPS
Because MPS does not support the QR decomposition that orthogonal initialization relies on, the weights are initialized on the CPU and then copied back to the device:
elif init_type == 'orthogonal':
    if torch.backends.mps.is_available():
        weight = m.weight.to("cpu")
        torch.nn.init.orthogonal_(weight, gain=gain)
        m.weight.data.copy_(weight.to(self.device))
    else:
        torch.nn.init.orthogonal_(m.weight, gain=gain)
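A self-contained variant of the same idea is sketched below. It differs from the snippet above in one respect, which is an assumption on our part: it checks the device of the weight itself rather than whether MPS is available on the machine, so CPU-resident modules take the direct path. The helper name init_orthogonal_ is hypothetical.

```python
import torch
import torch.nn as nn


def init_orthogonal_(module: nn.Module, gain: float = 1.0) -> None:
    """Orthogonally initialize module.weight, via the CPU for MPS tensors."""
    weight = module.weight
    if weight.device.type == "mps":
        # Build the orthogonal matrix on the CPU, then copy it to the device.
        cpu_weight = torch.empty(weight.shape, dtype=weight.dtype, device="cpu")
        nn.init.orthogonal_(cpu_weight, gain=gain)
        with torch.no_grad():
            weight.copy_(cpu_weight)
    else:
        nn.init.orthogonal_(weight, gain=gain)


conv = nn.Conv2d(3, 8, kernel_size=3)
init_orthogonal_(conv)

# Flattened to 8 rows by 27 columns, the rows should be orthonormal.
flat = conv.weight.detach().reshape(8, -1)
gram = flat @ flat.T
```

With gain=1.0 the product of the flattened weight and its transpose should be close to the identity matrix, which is a quick way to check that the copy preserved the initialization.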
VGG Model Modifications
VGG19 model usage was adapted for MPS constraints:
cnn = models.vgg19(weights='IMAGENET1K_V1').features
InstanceNorm2d Dynamic Allocation
InstanceNorm2d layers are now dynamically allocated with correct channel counts:
self.IN_layers = {
    'relu1_1': nn.InstanceNorm2d(64, affine=False, track_running_stats=False),
    'relu2_1': nn.InstanceNorm2d(128, affine=False, track_running_stats=False),
    'relu3_1': nn.InstanceNorm2d(256, affine=False, track_running_stats=False),
    'relu4_1': nn.InstanceNorm2d(512, affine=False, track_running_stats=False),
    'relu5_1': nn.InstanceNorm2d(512, affine=False, track_running_stats=False),
}
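The short check below illustrates what one of these layers does and why keeping them in a plain dict is workable. With affine=False and track_running_stats=False the layer holds no parameters or buffers, so nothing needs to be moved to the device. The tensor shape and scaling used here are arbitrary assumptions.

```python
import torch
import torch.nn as nn

norm = nn.InstanceNorm2d(64, affine=False, track_running_stats=False)

# Features with a deliberately shifted and scaled distribution.
features = torch.randn(2, 64, 16, 16) * 3.0 + 5.0
normalized = norm(features)

# Each (sample, channel) plane is normalized independently.
per_channel_mean = normalized.mean(dim=(2, 3))
per_channel_std = normalized.std(dim=(2, 3), unbiased=False)

# The layer is stateless in this configuration.
num_params = len(list(norm.parameters()))
num_buffers = len(list(norm.buffers()))
```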
Modern PyTorch Module Usage
Custom modules were replaced with PyTorch built-ins:
elif act_fun_type == 'Swish':
    return nn.SiLU()  # torch.nn.SiLU is now built in

elif norm_fun_type == 'none':
    norm_fun = nn.Identity
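As a quick sanity check on these substitutions, the sketch below compares nn.SiLU with the usual Swish definition x * sigmoid(x) and shows that nn.Identity accepts and ignores constructor arguments, which is what allows it to stand in for a normalization class. The input values are arbitrary.

```python
import torch
import torch.nn as nn

x = torch.linspace(-3.0, 3.0, steps=7)

# nn.SiLU computes x * sigmoid(x), the Swish activation with beta = 1.
silu_out = nn.SiLU()(x)
manual_swish = x * torch.sigmoid(x)

# nn.Identity ignores its constructor arguments and returns its input.
identity = nn.Identity(64)
passthrough = identity(x)
```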
Tensor API Modernization
Access through the .data attribute was replaced with .detach() for modern PyTorch compatibility:
x.detach()
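The sketch below shows the properties of .detach() that matter here: the result shares storage with the original tensor but is cut off from the autograd graph. Unlike .data, in-place changes made through a detached tensor remain visible to autograd's correctness checks. The tensor values are arbitrary.

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x.detach()

# Same underlying memory, but no gradient tracking on the detached view.
shares_storage = y.data_ptr() == x.data_ptr()
tracks_gradients = y.requires_grad
```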
Learning Rate Scheduler Updates
Learning rate scheduler calls were simplified to argument-free step() calls:
self.lr_scheduler_g.step()
self.lr_scheduler_d.step()
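A minimal sketch of the argument-free pattern follows. The StepLR scheduler, the SGD optimizer, and the toy parameter are assumptions chosen for illustration; the scheduler actually used by the project may differ. The point is the call order: optimizer.step() first, then scheduler.step() with no epoch argument.

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
# Halve the learning rate after every epoch (illustrative schedule only).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

learning_rates = []
for epoch in range(3):
    optimizer.zero_grad()
    loss = (param - 1.0).pow(2).sum()
    loss.backward()
    optimizer.step()   # update the weights first
    scheduler.step()   # then advance the schedule, with no arguments
    learning_rates.append(scheduler.get_last_lr()[0])
```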
Model Saving Improvements
Model saving now includes robust error handling and path management:
save_path = Path(self.model_save_path)
save_path.mkdir(parents=True, exist_ok=True)
model_filename = f"{self.args.version}_{self.args.adv_loss_type}_{current_epoch}.pth"
model_filepath = save_path / model_filename
try:
    torch.save(checkpoint, model_filepath)
    pbar.write(f"✅ Model checkpoint saved: {model_filepath}")
except Exception as e:
    pbar.write(f"❌ Error saving model checkpoint: {e}")
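The same logic can be sketched as a standalone function. The name save_checkpoint, its parameters, and the use of print in place of pbar.write are assumptions made so that the example runs without the trainer class or tqdm.

```python
import tempfile
from pathlib import Path
from typing import Optional

import torch


def save_checkpoint(checkpoint: dict, save_dir: str, filename: str) -> Optional[Path]:
    """Save a checkpoint and return its path, or None if saving failed."""
    save_path = Path(save_dir)
    # Create missing parent directories instead of failing on the first save.
    save_path.mkdir(parents=True, exist_ok=True)
    model_filepath = save_path / filename
    try:
        torch.save(checkpoint, model_filepath)
    except Exception as e:
        print(f"Error saving model checkpoint: {e}")
        return None
    return model_filepath


# Usage: save into a nested directory that does not exist yet.
with tempfile.TemporaryDirectory() as tmp:
    saved = save_checkpoint({"epoch": 1}, f"{tmp}/nested/models", "demo_1.pth")
    saved_ok = saved is not None and saved.exists()
    restored = torch.load(saved)
```

Returning None on failure lets the caller decide whether a failed save should stop training or only be logged.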
DataLoader Optimizations
Device detection and iterator initialization were improved in the data loader:
self.device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
self.iter = iter(self.loader)
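Creating the iterator up front is typically paired with restarting it once the loader is exhausted. The sketch below shows that pattern with a plain list standing in for a DataLoader; the class name RestartingLoader and the next_batch method are hypothetical, and the restart behaviour is an assumption about how the iterator is used.

```python
class RestartingLoader:
    """Wrap an iterable and restart it transparently when it runs out."""

    def __init__(self, loader):
        self.loader = loader
        self.iter = iter(self.loader)  # initialized once, as in the snippet above

    def next_batch(self):
        try:
            return next(self.iter)
        except StopIteration:
            # Start a new pass over the data and return its first batch.
            self.iter = iter(self.loader)
            return next(self.iter)


batches = RestartingLoader(["batch0", "batch1"])
seen = [batches.next_batch() for _ in range(5)]
```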
Batch Dimension Handling
Batch dimension handling was clarified for NIMA calculation:
image = image.unsqueeze(0)
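For reference, unsqueeze(0) inserts a leading batch dimension so that a single image matches the (N, C, H, W) layout that convolutional models expect. The 3 x 224 x 224 image size below is an arbitrary assumption.

```python
import torch

image = torch.rand(3, 224, 224)   # a single image: (C, H, W)
batch = image.unsqueeze(0)        # a batch of one: (1, C, H, W)
```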
Warning Handling Changes
The filter that escalated warnings to errors was commented out, so warnings no longer abort a run and can be inspected while debugging:
# import warnings
# warnings.simplefilter("error")
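The difference between the two behaviours can be seen with the standard library alone. In this sketch the function noisy is a hypothetical stand-in for any library call that emits a warning.

```python
import warnings


def noisy():
    warnings.warn("deprecated call", UserWarning)
    return "finished"


# Without the "error" filter, the warning is recorded and the call completes.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = noisy()

# With simplefilter("error"), the same warning is raised as an exception.
with warnings.catch_warnings():
    warnings.simplefilter("error")
    try:
        noisy()
        raised = False
    except UserWarning:
        raised = True
```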
Progress Reporting Enhancement
Progress is now reported more clearly using tqdm’s write method:
pbar.write((
    "Elapse:{:>.12s}, D_Step:{:>6d}/{}, G_Step:{:>6d}/{}, "
    "D_loss:{:>.4f}, G_loss:{:>.4f}, G_percep_loss:{:>.4f}, "
    "G_adv_loss:{:>.4f}, G_idt_loss:{:>.4f}"
).format(
    elapsed, step + 1, total_steps, (step + 1), total_steps,
    self.d_loss, self.g_loss, self.g_percep_loss,
    self.g_adv_loss, self.g_idt_loss
))
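The format specifiers used above can be tried in isolation. In this shortened sketch, format_progress is a hypothetical helper and the sample values are made up; print stands in for pbar.write so the example runs without tqdm.

```python
def format_progress(elapsed, step, total_steps, d_loss, g_loss):
    """Build a shortened version of the progress line shown above."""
    return (
        "Elapse:{:>.12s}, D_Step:{:>6d}/{}, "
        "D_loss:{:>.4f}, G_loss:{:>.4f}"
    ).format(elapsed, step + 1, total_steps, d_loss, g_loss)


# {:>.12s} truncates the string to 12 characters,
# {:>6d} right-aligns the integer in a field of width 6,
# {:>.4f} prints four digits after the decimal point.
line = format_progress("0:01:23.456789", 41, 1000, 0.123456, 1.5)
print(line)
# Elapse:0:01:23.4567, D_Step:    42/1000, D_loss:0.1235, G_loss:1.5000
```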
Random Pair Generation Improvement
Random pair generation for unsupervised learning was improved by shuffling the two file lists independently:
random.shuffle(fnames)
random.shuffle(fnames2)
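The effect of shuffling both lists independently can be sketched as follows. The helper make_random_pairs, its seed parameter, and the sample file names are assumptions for illustration; we also assume that fnames and fnames2 hold the two unpaired image sets.

```python
import random


def make_random_pairs(fnames, fnames2, seed=None):
    """Pair items from two lists after shuffling each one independently."""
    rng = random.Random(seed)
    first = list(fnames)    # copy so the caller's lists are left untouched
    second = list(fnames2)
    rng.shuffle(first)
    rng.shuffle(second)
    # zip stops at the shorter list if the lengths differ.
    return list(zip(first, second))


low_quality = ["a1.png", "a2.png", "a3.png"]
high_quality = ["b1.png", "b2.png", "b3.png"]
pairs = make_random_pairs(low_quality, high_quality, seed=0)
```

Because each list is shuffled on its own, no fixed correspondence between the two sets is implied, which matches the unpaired training setting.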
These changes make the UEGAN codebase fully compatible with Apple MPS, remove TensorFlow dependencies, and modernize the implementation to align with current PyTorch best practices.
Original Source
This implementation is based on:
- Paper: Towards Unsupervised Deep Image Enhancement with Generative Adversarial Network (Ni et al., IEEE Transactions on Image Processing, 2020)
- Code: eezkni/UEGAN
We acknowledge the original authors for their foundational work.