
MLP-Mixer re-invented for auto-encoding

Recently i read the paper "MLP-Mixer: An all-MLP Architecture for Vision" (arXiv 2105.01601), which i found quite inspiring, but which also left me with some regret about not having come up with it myself.

Anyways, after some weeks i set out to program a network based on those ideas, without looking at the paper for reference, at least not before i had a working implementation myself. And of course, because i like auto-encoders so much, i used it for auto-encoding.

Setup is as follows:

First, split the input image into equal-sized patches. I'm using pytorch's unfold method for this. For the MNIST digits dataset this yields a 4x4 grid of 7x7 patches for each 28x28 image.

Then the patches are flattened into 16 1-D vectors of size 49. Those vectors are moved into the batch dimension (the first one) of the tensor that the linear layers are going to process. For example, when passing in a batch size of 64 (64 images at once):

input:                      (64, 1, 28, 28)
patches:                    (64, 4, 4, 1, 7, 7)
flatten:                    (64, 16, 49)
move to batch dimension:    (1024, 49)
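
As a reference, here is a minimal sketch of how this patching could be done with unfold (hypothetical code, the actual Patchify module may differ in details):

import torch

def patchify(x: torch.Tensor, patch_size: int = 7) -> torch.Tensor:
    # x: (B, C, H, W), e.g. (64, 1, 28, 28)
    p = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    p = p.permute(0, 2, 3, 1, 4, 5)   # (64, 4, 4, 1, 7, 7)
    p = p.flatten(3)                  # (64, 4, 4, 49)
    p = p.flatten(1, 2)               # (64, 16, 49)
    return p.flatten(0, 1)            # (1024, 49)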

Now these 1024 49-D vectors can be processed by a stack of linear layers (an MLP). Each vector is processed by the same layers, and since each one is only a little patch, the linear layers can be very small. However, no information is shared between the patches; each patch is processed individually. To apply the "mixing" strategy, i first tried a 1-D convolutional layer:

input:                      (1024, 49)
reverse batch-flattening:   (64, 16, 49)
apply Conv1d(16, 16, 1):    (64, 16, 49)
move to batch dimension:    (1024, 49)
process further with MLP:   (1024, 49)

The convolution treats the patch dimension as channels and is thereby able to mix information between the different patches.
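
A minimal sketch of such a mixing layer (hypothetical code, the repo's MLPMixerLayer may differ in details):

import torch
import torch.nn as nn

class ConvMixer(nn.Module):
    # mixes information across the patch dimension with a 1-D convolution
    def __init__(self, num_patches: int = 16):
        super().__init__()
        self.conv = nn.Conv1d(num_patches, num_patches, kernel_size=1)

    def forward(self, x: torch.Tensor, batch_size: int = 64) -> torch.Tensor:
        y = x.view(batch_size, -1, x.shape[-1])   # (64, 16, 49)
        y = self.conv(y)                          # mix between the 16 patches
        return y.flatten(0, 1) + x                # residual, back to (1024, 49)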

Note that the 49 is the width of each layer and does not have to stay like this; each linear layer can make it larger or smaller. In the following experiment, i resized the 49 to 64 in the first layer and kept it like this until the final linear layers, which compress everything down to a 16-dim latent vector. That yields a compression ratio of 49 for the auto-encoder (28x28 = 784 pixels -> 16 values).

Also, whenever a layer has the same input and output dimension, a residual connection is added (after the activation function).
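
In pseudo-code, the rule could look like this (a sketch, not the repo's exact MLPLayer):

def mlp_layer(x, linear, act):
    # residual is added after the activation, and only if input and output shapes match
    y = act(linear(x))
    if linear.in_features == linear.out_features:
        y = y + x
    return y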

The whole auto-encoder is built symmetrically, so the decoder stack is the reverse of the encoder stack. Here is one example in pytorch speak:

MixerMLP(
  (patchify): Patchify(patch_size=7)
  (encoder): ModuleList(
    (0): MLPLayer(
      residual=False
      (module): Linear(in_features=49, out_features=64, bias=True)
      (act): GELU(approximate='none')
    )
    (1): MLPLayer(
      residual=True
      (module): Linear(in_features=64, out_features=64, bias=True)
      (act): GELU(approximate='none')
    )
    (2): MLPMixerLayer(
      residual=True
      (module): Conv1d(16, 16, kernel_size=(1,), stride=(1,))
    )
    (3): MLPLayer(
      residual=True
      (module): Linear(in_features=64, out_features=64, bias=True)
      (act): GELU(approximate='none')
    )
    (4): MLPMixerLayer(
      residual=True
      (module): Conv1d(16, 16, kernel_size=(1,), stride=(1,))
    )
    (5): MLPLayer(
      residual=False
      (module): Linear(in_features=64, out_features=16, bias=True)
      (act): GELU(approximate='none')
    )
    (6): MLPMixerLayer(
      residual=True
      (module): Conv1d(16, 16, kernel_size=(1,), stride=(1,))
    )
    (7): MLPLayer(
      residual=False
      (module): Linear(in_features=256, out_features=16, bias=True)
    )
  )
  (decoder): ModuleList(
    (0): MLPLayer(
      residual=False
      (module): Linear(in_features=16, out_features=256, bias=True)
    )
    (1): MLPMixerLayer(
      residual=True
      (module): Conv1d(16, 16, kernel_size=(1,), stride=(1,))
      (act): GELU(approximate='none')
    )
    (2): MLPLayer(
      residual=False
      (module): Linear(in_features=16, out_features=64, bias=True)
      (act): GELU(approximate='none')
    )
    (3): MLPMixerLayer(
      residual=True
      (module): Conv1d(16, 16, kernel_size=(1,), stride=(1,))
      (act): GELU(approximate='none')
    )
    (4): MLPLayer(
      residual=True
      (module): Linear(in_features=64, out_features=64, bias=True)
      (act): GELU(approximate='none')
    )
    (5): MLPMixerLayer(
      residual=True
      (module): Conv1d(16, 16, kernel_size=(1,), stride=(1,))
      (act): GELU(approximate='none')
    )
    (6): MLPLayer(
      residual=True
      (module): Linear(in_features=64, out_features=64, bias=True)
      (act): GELU(approximate='none')
    )
    (7): MLPLayer(
      residual=False
      (module): Linear(in_features=64, out_features=49, bias=True)
    )
  )
  (unpatchify): Unpatchify(patch_size=7)
)

Note that, in the encoder, layer 5 compresses each patch from 64 to 16 dimensions, then there is another mixing layer, and layer 7 processes all 16 patches at once and outputs a final 16-dim vector (256 -> 16).

As with other MLPs, this auto-encoder is input-size dependent. The network is constructed for a fixed image size and cannot be used for other sizes.

In the following table, bs = batch size, lr = learning rate, ds = dataset, ps = patch size, ch = the channel (hidden dimension) sizes of the MLP layers, mix = at which layers a mixing layer is added (1-indexed), act = activation function. The optimizer is AdamW, the loss function is mean-squared error and the learning rate scheduler is cosine annealing with warmup.
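
For orientation, a rough sketch of the training setup (the scheduler composition and warmup length are my assumptions, not taken from the experiment file):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(784, 16), nn.Linear(16, 784))  # stand-in for the MixerMLP above
optimizer = torch.optim.AdamW(model.parameters(), lr=0.003)
criterion = nn.MSELoss()
# cosine annealing with a linear warmup; the warmup length is assumed
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[
        torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=10_000),
        torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=990_000),
    ],
    milestones=[10_000],
)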

Experiment file: experiments/ae/mixer/mixer02.yml @ 18e32d81

bs | lr | ds | ps | ch | mix | act | validation loss (1,000,000 steps) | model params | train time (minutes) | throughput
64 0.003 mnist 7 64,64,64,16 gelu 0.011366 33,617 1.03 16,126/s
64 0.003 mnist 7 64,64,64,16 4 gelu 0.011248 34,161 1.25 13,335/s
64 0.003 mnist 7 64,64,64,16 3 gelu 0.010779 34,161 1.21 13,779/s
64 0.003 mnist 7 64,64,64,16 2 gelu 0.010677 34,161 1.2 13,865/s
64 0.003 mnist 7 64,64,64,16 1 gelu 0.010659 34,161 1.21 13,748/s
64 0.003 mnist 7 64,64,64,16 1,2 gelu 0.010418 34,705 1.29 12,883/s
64 0.003 mnist 7 64,64,64,16 1,2,3 gelu 0.010398 35,249 1.46 11,384/s
64 0.003 mnist 7 64,64,64,16 1,2,3,4 gelu 0.010208 35,793 1.61 10,323/s

The validation loss is not actually super good and the mixing only helps a little. But it's not too bad either for such a small network.

Linear layer as mixer

So let's try a linear layer for mixing between patches.

input:                      (1024, 49)
reverse batch-flattening:   (64, 16, 49)
flatten 2nd dimension:      (64, 784)
apply Linear(784, 784):     (64, 784)
unflatten 2nd dimension:    (64, 16, 49)
and move to batch dim:      (1024, 49)
process further with MLP:   (1024, 49)
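
A sketch of this fully connected mixing stage (again hypothetical code, simplified):

import torch
import torch.nn as nn

class LinearMixer(nn.Module):
    # one big linear layer over all patches at once
    def __init__(self, num_patches: int = 16, channels: int = 49):
        super().__init__()
        self.linear = nn.Linear(num_patches * channels, num_patches * channels)

    def forward(self, x: torch.Tensor, batch_size: int = 64) -> torch.Tensor:
        y = x.view(batch_size, -1)     # (64, 784)
        y = self.linear(y)             # every value can interact with every other one
        return y.view(x.shape) + x     # residual, back to (1024, 49)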

Obviously, the linear mixing layer increases the network size considerably. The benefit is that every pixel of every patch can be compared/processed against every other one.

I guess this is kind of what the MLP-Mixer paper suggests, only i can not confidently read their jax code to confirm it.

bs | lr | ds | ps | ch | mix | act | validation loss (1,000,000 steps) | model params | train time (minutes) | throughput
64 0.003 mnist 7 64,64,64,16 4 gelu 0.0105164 165,201 1.16 14,418/s
64 0.003 mnist 7 64,64,64,16 1 gelu 0.0078778 2,132,817 1.17 14,292/s
64 0.003 mnist 7 64,64,64,16 2 gelu 0.0072450 2,132,817 1.16 14,345/s
64 0.003 mnist 7 64,64,64,16 3 gelu 0.0071680 2,132,817 1.21 13,793/s
64 0.003 mnist 7 64,64,64,16 3,4 gelu 0.0069653 2,264,401 1.41 11,778/s
64 0.003 mnist 7 64,64,64,16 1,2 gelu 0.0068538 4,232,017 1.58 10,559/s
64 0.003 mnist 7 64,64,64,16 1,2,3,4 gelu 0.0066289 6,462,801 2.06 8,083/s
64 0.003 mnist 7 64,64,64,16 1,2,3 gelu 0.0065175 6,331,217 1.87 8,894/s
64 0.003 mnist 7 64,64,64,16 2,3,4 gelu 0.0065056 4,363,601 1.62 10,303/s

Here's a plot of the loss curves for both mixing methods:

(figure: loss curves for the convolutional vs. the linear mixing method)

Performance is much better but the number of model parameters is much larger as well.

Comparison with baseline CNN

How well does it perform against a baseline CNN with a final linear layer?

EncoderDecoder(
  (encoder): EncoderConv2d(
    (convolution): Conv2dBlock(
      (layers): Sequential(
        (0): Conv2d(1, 24, kernel_size=(5, 5), stride=(1, 1))
        (1): ReLU()
        (2): Conv2d(24, 32, kernel_size=(5, 5), stride=(1, 1))
        (3): ReLU()
        (4): Conv2d(32, 48, kernel_size=(5, 5), stride=(1, 1))
        (5): ReLU()
      )
    )
    (linear): Linear(in_features=12288, out_features=16, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=16, out_features=12288, bias=True)
    (1): Reshape(shape=(48, 16, 16))
    (2): Conv2dBlock(
      (layers): Sequential(
        (0): ConvTranspose2d(48, 32, kernel_size=(5, 5), stride=(1, 1))
        (1): ReLU()
        (2): ConvTranspose2d(32, 24, kernel_size=(5, 5), stride=(1, 1))
        (3): ReLU()
        (4): ConvTranspose2d(24, 1, kernel_size=(5, 5), stride=(1, 1))
      )
    )
  )
)

ks is kernel size

ch | ks | validation loss (1,000,000 steps) | model params | train time (minutes) | throughput
24,32,48 9 0.00780634 402,657 1.87 8,914/s
24,32,48 7 0.00710207 386,721 1.44 11,557/s
24,32,48 3 0.00718373 808,737 2.11 7,908/s
24,32,48 5 0.00655708 522,081 1.93 8,643/s

Seems to be equally good at first glance. However, training both networks a bit longer reveals that the CNN is able to squeeze out some more performance:

(figure: loss curves)

After 6 million steps (or 100 epochs), the best performing Mixer-MLP from above achieves a validation loss of 0.0063 vs. 0.0055 for the CNN model.

The Mixer-MLP is 4 times larger than the CNN model (because of the 3 linear mixing stages) and we can see from the loss curves that it actually overfits the training set. Its validation performance starts to get worse after 5 million steps.

However, since it seems to behave quite differently from the baseline CNN, it's certainly a nice new tool to try things on.

It would be interesting to see if the mixing stages, when implemented with convolutions, can somehow reach the performance of the fully connected linear mixers. That would make the network very small and powerful.

Tweaking the convolutional mixing stage

First of all, we can use a kernel size larger than 1. That would allow the mixing stage to correlate nearby pixels of the individual patches. Since it's a 1-D convolution, it would still only correlate pixels to the left and right, not the ones above and below.

Secondly, we might use two convolutions: one that mixes the patches and one that mixes the pixels. Following the example above, that would look like this:

input:                      (1024, 49)
reverse batch-flattening:   (64, 16, 49)
apply Conv1d(16, 16):       (64, 16, 49)
transpose vector:           (64, 49, 16)
apply Conv1d(49, 49):       (64, 49, 16)
transpose vector:           (64, 16, 49)
move to batch dimension:    (1024, 49)
process further with MLP:   (1024, 49)
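
A sketch of this step (hypothetical helper function; the conv modules are assumed to be constructed with the appropriate padding):

import torch
import torch.nn as nn

def mix_cnnf(x: torch.Tensor, conv_patches: nn.Conv1d, conv_pixels: nn.Conv1d,
             batch_size: int = 64) -> torch.Tensor:
    # x: (1024, 49); conv_patches: Conv1d(16, 16, ks); conv_pixels: Conv1d(49, 49, ks)
    y = x.view(batch_size, 16, 49)
    y = conv_patches(y)                                  # mix between the 16 patches
    y = conv_pixels(y.transpose(1, 2)).transpose(1, 2)   # mix between the 49 pixels
    return y.flatten(0, 1)                               # back to (1024, 49)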

And thirdly, we could actually do a 2-dimensional convolutional mixing:

input:                      (1024, 49)
reverse batch-flattening:   (64, 16, 49)
restore 2-D patches:        (64, 16, 7, 7) 
apply Conv2d(16, 16):       (64, 16, 7, 7)
flatten the patches:        (64, 16, 49)
move to batch dimension:    (1024, 49)
process further with MLP:   (1024, 49)
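
As a sketch (hypothetical helper, same assumptions as above):

import torch
import torch.nn as nn

def mix_cnn2d(x: torch.Tensor, conv: nn.Conv2d,
              batch_size: int = 64, patch_size: int = 7) -> torch.Tensor:
    # x: (1024, 49); conv: Conv2d(16, 16, ks, padding=(ks - 1) // 2)
    y = x.view(batch_size, 16, patch_size, patch_size)   # restore 2-D patches
    y = conv(y)                                          # mix all patches and neighbouring positions
    return y.flatten(2).flatten(0, 1)                    # back to (1024, 49)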

However, since the final dimension is not actually fixed at 49 (7x7) but can be anything, it does not really represent the 2-dimensional patch content but rather just the computational breadth of the MLP stack at each layer. We might still treat it as spatial and let the network figure out some good way to use it.

Fourthly, if that is even a word, we can combine the two-consecutive-convolutions method with the 2-D convolution, which puts another constraint on the number of patches resulting from the input image: it also has to be a square number. It gets a bit complicated then:

input:                      (1024, 49)
reverse batch-flattening:   (64, 16, 49)
restore 2-D patches:        (64, 16, 7, 7) 
apply Conv2d(16, 16, 1):    (64, 16, 7, 7)
flatten patches:            (64, 16, 49)
transpose dimensions:       (64, 49, 16)
make "breadth" 2-dim:       (64, 49, 4, 4)
apply Conv2d(49, 49, 1):    (64, 49, 4, 4)
flatten "breadth" dim:      (64, 49, 16)
transpose dimensions:       (64, 16, 49)
move to batch dimension:    (1024, 49)
process further with MLP:   (1024, 49)
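
Or, as a hypothetical helper function (the conv modules again carry their own kernel size and padding):

import torch
import torch.nn as nn

def mix_cnn2df(x: torch.Tensor, conv_a: nn.Conv2d, conv_b: nn.Conv2d,
               batch_size: int = 64, patch_size: int = 7, grid: int = 4) -> torch.Tensor:
    # x: (1024, 49); conv_a: Conv2d(16, 16, ks); conv_b: Conv2d(49, 49, ks)
    y = x.view(batch_size, 16, patch_size, patch_size)   # (64, 16, 7, 7)
    y = conv_a(y)                                        # first 2-D mixing
    y = y.flatten(2).transpose(1, 2)                     # (64, 49, 16)
    y = y.reshape(batch_size, -1, grid, grid)            # (64, 49, 4, 4)
    y = conv_b(y)                                        # second 2-D mixing over the patch grid
    y = y.flatten(2).transpose(1, 2)                     # (64, 16, 49)
    return y.reshape(-1, patch_size * patch_size)        # back to (1024, 49)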

Okay, let's test all the methods with kernel sizes from 1 to 13. The padding of the convolutions is adjusted ((ks - 1) // 2) so that they always produce the same output size.

In the table below, mixtype is one of:

cnn: a single 1-D convolution across the patch dimension (the original mixing layer from above)
cnnf: two consecutive 1-D convolutions, one mixing the patches and one mixing the pixels
cnn2d: a single 2-D convolution on the restored 2-D patches
cnn2df: two consecutive 2-D convolutions, one on the restored patches and one on the 2-D reshaped "breadth" dimension

bs | lr | ds | ps | ch | mix | mixtype | ks | act | validation loss (1,000,000 steps) | model params | train time (minutes) | throughput
64 0.003 mnist 7 64,64,64,16 2,3,4 cnn 1 gelu 0.0105176 35,249 1.53 10,866/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnn2d 1 gelu 0.0103852 35,249 1.45 11,517/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnn 3 gelu 0.0101478 38,321 1.46 11,447/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnn 5 gelu 0.0097217 41,393 1.50 11,076/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnn2df 1 gelu 0.0096381 52,433 2.09 7,989/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnnf 1 gelu 0.0095420 52,433 1.90 8,772/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnn 7 gelu 0.0095143 44,465 1.42 11,726/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnn 9 gelu 0.0094267 47,537 1.45 11,489/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnn2d 3 gelu 0.0093458 47,537 1.65 10,115/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnn 11 gelu 0.0092456 50,609 1.47 11,327/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnn 13 gelu 0.0092080 53,681 1.50 11,107/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnn2d 5 gelu 0.0086525 72,113 1.79 9,292/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnn2d 7 gelu 0.0083488 108,977 1.87 8,911/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnn2d 9 gelu 0.0080035 158,129 2.05 8,120/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnnf 3 gelu 0.0079969 89,297 1.98 8,404/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnn2d 13 gelu 0.0077707 293,297 2.02 8,240/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnn2d 11 gelu 0.0077622 219,569 1.88 8,865/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnnf 5 gelu 0.0075978 126,161 2.31 7,211/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnnf 7 gelu 0.0073687 163,025 2.47 6,740/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnn2df 3 gelu 0.0071530 199,889 2.12 7,855/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnnf 9 gelu 0.0071354 199,889 2.62 6,351/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnnf 13 gelu 0.0070266 273,617 2.87 5,815/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnnf 11 gelu 0.0070069 236,753 2.52 6,625/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnn2df 5 gelu 0.0068095 494,801 2.35 7,086/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnn2df 13 gelu 0.0067350 3,149,009 7.85 2,123/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnn2df 11 gelu 0.0066938 2,264,273 6.05 2,755/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnn2df 9 gelu 0.0066766 1,526,993 4.45 3,745/s
64 0.003 mnist 7 64,64,64,16 2,3,4 cnn2df 7 gelu 0.0066100 937,169 2.43 6,872/s

Okay, first of all, just increasing the kernel size is not so helpful. As stated above, it only connects pixels to the left and right but not the ones above and below. The cnn mixer achieves a validation loss between 0.01 and 0.0092 for kernel sizes 1 to 13.

The second option, two convolutions for mixing (cnnf), works pretty well in comparison. The losses range between 0.0095 and 0.007 for the different kernel sizes.

Third option (cnn2d): the spatial 2-D convolutions work okayish. They range between 0.01 (similar to the first option with kernel size 1) and 0.0078. It should be noted that the huge kernel sizes do not make much sense in this example, since the spatial patches are only 7x7 or 4x4. It would make a difference for larger input sizes, though.

Fourth option (cnn2df): two spatial convolutions work best in this example. The losses range between 0.0096 (similar to cnnf) and 0.0066, almost reaching the performance of the network with the fully connected linear mixers while having far fewer parameters (at least the kernel-size-7 variant).

So, let's look at what happens when processing something larger than the MNIST digits.

Increasing resolution

In ML papers, when larger images are required, it's almost always the ImageNet dataset that is used. Well, i sent them an access request but never got an answer. If it's just about auto-encoding, one could simply crawl the web for a hundred thousand random images, but that diminishes reproducibility.

So i chose the Unsplash Lite dataset, although it has some minor issues at the moment.

These are about 25,000 images, which i scaled down to 160 pixels for the longest edge (but at least 96 pixels for the smallest edge). Random 96x96 crops are taken for training. The first 100 images are used for validation and are center-cropped to 96x96.
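
Roughly, in torchvision terms (a sketch; the offline rescaling is assumed to have happened beforehand and the exact pipeline may differ):

from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomCrop(96),    # random 96x96 crop of the pre-scaled image
    transforms.ToTensor(),
])
validation_transform = transforms.Compose([
    transforms.CenterCrop(96),    # first 100 images, center-cropped
    transforms.ToTensor(),
])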

To keep the auto-encoder compression ratio of 49 for comparison, the latent code has a size of 564 (3*96*96 / 564 = 49.0213).

First, i tried different patch sizes with the cnn and cnnf mixers.

bs | lr | ds | ps | ch | mix | mixtype | ks | act | validation loss (1,000,000 steps) | model params | train time (minutes) | throughput
64 0.0003 unsplash 32 64,64,64,564 2,3,4 cnn 3 gelu 0.0073214 6,218,692 10.03 1,660/s
64 0.0003 unsplash 32 64,64,64,564 2,3,4 cnnf 3 gelu 0.0057957 8,177,804 10.27 1,622/s
64 0.0003 unsplash 16 64,64,64,564 2,3,4 cnn 3 gelu 0.0034742 23,135,920 11.07 1,505/s
64 0.0003 unsplash 16 64,64,64,564 2,3,4 cnnf 3 gelu 0.0033309 25,095,032 11.63 1,432/s
64 0.0003 unsplash 8 64,64,64,564 2,3,4 cnn 3 gelu 0.0032881 92,181,832 19.27 864/s
64 0.0003 unsplash 8 64,64,64,564 2,3,4 cnnf 3 gelu 0.0031301 94,140,944 32.87 507/s

Whoa, what happened here? First of all, the number of parameters is huge in comparison, and then the validation loss beats every previous model run on the MNIST digits dataset. Let's look closer at both phenomena.

The increased parameter count can be blamed mostly on a single layer. Look at the encoder of the patch-size-8 model:

(patchify): Patchify(patch_size=8)
(encoder): ModuleList(
    (0): MLPLayer(
      residual=False
      (module): Linear(in_features=192, out_features=64, bias=True)
      (act): GELU(approximate='none')
    )
    (1): MLPLayer(
      residual=True
      (module): Linear(in_features=64, out_features=64, bias=True)
      (act): GELU(approximate='none')
    )
    (2): MLPMixerLayer(
      residual=True
      (module): Conv1d(144, 144, kernel_size=(3,), stride=(1,), padding=(1,))
      (act): GELU(approximate='none')
    )
    (3): MLPLayer(
      residual=True
      (module): Linear(in_features=64, out_features=64, bias=True)
      (act): GELU(approximate='none')
    )
    (4): MLPMixerLayer(
      residual=True
      (module): Conv1d(144, 144, kernel_size=(3,), stride=(1,), padding=(1,))
      (act): GELU(approximate='none')
    )
    (5): MLPLayer(
      residual=False
      (module): Linear(in_features=64, out_features=564, bias=True)
      (act): GELU(approximate='none')
    )
    (6): MLPMixerLayer(
      residual=True
      (module): Conv1d(144, 144, kernel_size=(3,), stride=(1,), padding=(1,))
      (act): GELU(approximate='none')
    )
    (7): MLPLayer(
      residual=False
      (module): Linear(in_features=81216, out_features=564, bias=True)
    )
)

The 96x96 images are sliced into 144 8x8 patches. The convolutions require 144 * 144 * 3 * 3 = 186,624 parameters each (plus the bias), so their size is acceptable.

The fifth layer expands to a hidden dimension of 564 and the final mixdown-to-latent-code layer (let's call it the head) gets a 144 * 564 = 81,216-dim vector as input and mixes it down to 564: that's 81,216 * 564 = 45,805,824 weights! The decoder has the same thing in reverse. That's completely exaggerated for the head of an otherwise tiny network.

Quite obviously, there is no need to expand the hidden dimensions to 564 before the head. That was just how i wrote the code without thinking too much when the networks were small. That needs to be fixed.

So what about the validation loss of around 0.003? It turns out that photos generally yield a lower reconstruction error than the high-contrast black-and-white MNIST digits. It does not look too great, though:

(figure: reconstructed images)

(original on top, reconstruction below)

To fix the problem with the large final layer, there is now one additional channel number in the model configuration, which defines the breadth of the final layer before the head. To keep the networks small, i used 32 (or 36 to make it square for the 2-D convolutions). Here are a few other setups:

bs | lr | ds | ps | ch | mix | mixtype | ks | act | validation loss (1,000,000 steps) | model params | train time (minutes) | throughput
64 0.0003 unsplash 8 144,144,144,36,564 2,3,4 cnn2df 7 gelu 0.0544467 16,292,160 37.2 448/s
64 0.0001 unsplash 8 144x9...,36,564 2,4,6,8 cnn2df 7 gelu 0.0036219 22,512,888 60.97 273/s
64 0.0003 unsplash 8 64,64,64,32,564 2,3,4 cnnf 3 gelu 0.0033642 5,678,388 10.78 1,546/s
64 0.0003 unsplash 8 64,64,64,32,564 2,3,4 cnnf 13 gelu 0.0033175 7,106,868 12.61 1,322/s
64 0.0003 unsplash 8 64,64,64,36,564 2,3,4 cnn2df 7 gelu 0.0032664 12,926,880 23.34 713/s

The networks are still pretty big, which feels unjustified if i trust my gut feeling. There is also no increase in performance from using the 2-D convolutions or from making the network deeper (144x9... means an MLP with 9 layers of 144 channels each). It actually starts to feel like a waste of energy, literally.

Just to see what else is possible, i compare against a recent "baseline" CNN-only model that uses no MLP layers but instead decreases the spatial resolution via torch's PixelUnshuffle method. This moves parts of the spatial resolution into the channel dimension, which can then be gradually decreased to match a certain compression ratio. In this case, the network transforms the 3x96x96 input into 1x24x24, which corresponds to a ratio of 48.

The "script" explains the design of the network: ch=X means: set the size of the color/channel-dimension to X via a 2-D convolution and activation function. ch*X or ch/X multiplies or divides the number of channels via such a convolution layer. If the channel number is unchanged (e.g. ch*1) the layer also has a residual skip connection. downX is the PixelUnshuffle transform. E.g. down4 means, transform C x H x W to C*4*4 x H/4 x W/4. And the ch=32|down4 part means transform input image 3x96x96 to 32x96x96 and the unshuffle to 512x24x24.

script | ks | act | validation loss (1,000,000 steps) | model params | train time (minutes) | throughput
ch=32|down4|ch/2|ch/2|ch1|ch1|ch=1 3 relu6 0.00282578 3,248,612 36.04 462/s
scl=3|ch=32|down4|ch/2|ch/2|ch1|ch1|ch=1 3 relu6 0.00232538 15,446,750 179.57 92/s

The scl=3 in the second network script sets the number of skip-convolution layers for all the following convolutions. The idea is borrowed from Latent Space Characterization of Autoencoder Variants (2412.04755), where it is called skip-connected triple convolution. In our case it means: for each convolution, actually do three of them and add a residual skip connection around them.
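
A rough sketch of how such a block might look (my interpretation, not the paper's reference implementation):

import torch
import torch.nn as nn

class SkipTripleConv(nn.Module):
    # three stacked convolutions with a single residual connection around them
    def __init__(self, channels: int, kernel_size: int = 3, num_convs: int = 3):
        super().__init__()
        pad = (kernel_size - 1) // 2
        self.convs = nn.Sequential(*(
            nn.Sequential(nn.Conv2d(channels, channels, kernel_size, padding=pad), nn.ReLU6())
            for _ in range(num_convs)
        ))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.convs(x)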

Performance is certainly better than that of the Mixer-MLPs above. Still, it feels much like a waste of energy and time to pursue, e.g., twice the performance with otherwise similar models.

(figure: reconstructed images)

I'd rather stop the high-res experiments for now (well, 96² is not high resolution by any modern standard, but you know what i mean) and instead try some new tricks at lower resolution for faster experiment iterations.