This content originally appeared on Level Up Coding – Medium and was authored by Muhammad Ryan

Delve deeper into Fourier Neural Operator: Explanation and use case for a simple turbulence prediction

Recently, I’ve been exploring whether neural networks can mimic what numerical models do, but without some of their limitations. One limitation is that a numerical model has to compute every coupled variable even if we care about only one. For example, to predict pressure a few timesteps ahead with the Navier-Stokes equations, we also need to compute the other variables, such as density and the velocity components U, V, and W. Even for a small area, that takes considerable computing resources. Another reason numerical models need so much computational power is that the timestep must be matched to the spatial resolution to keep the computation stable: for an explicit scheme, the finer the grid, the smaller the timestep has to be. This constraint is called the Courant-Friedrichs-Lewy (CFL) condition, and it forces a trade-off between how far ahead each timestep reaches and how detailed the prediction is. While looking into this, I came across the Fourier Neural Operator (FNO). So, let’s first familiarise ourselves with the model!
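To make the CFL trade-off concrete, here is a minimal sketch. It assumes the common one-dimensional form of the condition, u * dt / dx <= C_max, for an explicit scheme; the function name and the numbers are only illustrative.

```python
def max_stable_timestep(dx, u_max, c_max=1.0):
    """Largest timestep allowed by the CFL condition u_max * dt / dx <= c_max."""
    if dx <= 0 or u_max <= 0:
        raise ValueError("dx and u_max must be positive")
    return c_max * dx / u_max


# Hypothetical example: 50 m/s flow on a 1 km grid versus a 500 m grid
dt_coarse = max_stable_timestep(dx=1000.0, u_max=50.0)
dt_fine = max_stable_timestep(dx=500.0, u_max=50.0)
print(dt_coarse, dt_fine)
```

Halving the grid spacing halves the allowed timestep, so a more detailed simulation also needs more steps to cover the same forecast period.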

Basic FNO

The basic flowchart of FNO looks like the chart below. When used on image data, the input has 5 dimensions: B (batch), T (timestep), C (channel), H (height), and W (width). The output could have the same layout, but usually we predict only 1 timestep, so the output shape reduces to 4 dimensions (B, C, H, W), and if we predict only 1 parameter, it reduces further to (B, H, W). The 5-dimensional layout therefore applies only to the input.

The feature extractor is a layer that makes the shape of the input fit what the FNO block expects. Usually, it maps the total depth (number of timesteps times number of channels) to the value of the “hidden depth” hyperparameter. The details of this hyperparameter will become clear when I show the code.
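The shape bookkeeping can be sketched with NumPy. This is only an illustration of the idea: folding T and C into one depth axis and then lifting that depth with a per-pixel linear map. The sizes, the value of hidden_depth, and the random lifting matrix are assumptions, not the layers of the adapted code.

```python
import numpy as np

B, T, C, H, W = 2, 10, 1, 64, 128
x = np.random.rand(B, T, C, H, W).astype(np.float32)

# Fold timesteps and channels into one "depth" axis: (B, T*C, H, W)
x_folded = x.reshape(B, T * C, H, W)

# A lifting layer is a linear map over depth, applied at every pixel
hidden_depth = 32  # hypothetical hyperparameter value
lift = np.random.rand(T * C, hidden_depth).astype(np.float32)
x_lifted = np.einsum("bdhw,dk->bkhw", x_folded, lift)

print(x_folded.shape, x_lifted.shape)
```

Whatever the input depth is, the FNO blocks afterwards always see hidden_depth channels.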

The next block is the FNO block. Its detailed process is displayed in the red box in the image above. After the input is processed by the feature extractor (or by the previous FNO block), it enters an FNO block. There, we apply the Discrete Fourier Transform (DFT) to bring the input from physical space to frequency space. In practice, the DFT is computed with the Fast Fourier Transform (FFT), the most efficient algorithm for it, which is why other literature often just refers to the FFT. I’ve already written about the Fourier transform in detail in my other article here.

Insight to the Fourier Transform and The Simple Implementation of It
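To see that the FFT is just a faster way to compute the DFT, the short sketch below builds the DFT directly from its definition and compares it with NumPy’s FFT. The helper naive_dft is a hypothetical name introduced for this illustration.

```python
import numpy as np


def naive_dft(x):
    """Direct O(N^2) DFT: X[k] = sum_m x[m] * exp(-2*pi*i*k*m/N)."""
    x = np.asarray(x, dtype=complex)
    n = len(x)
    k = np.arange(n)
    dft_matrix = np.exp(-2j * np.pi * np.outer(k, k) / n)
    return dft_matrix @ x


signal = np.random.default_rng(0).standard_normal(64)
same_result = np.allclose(naive_dft(signal), np.fft.fft(signal))
print(same_result)
```

Both give the same coefficients; the FFT only reorganises the arithmetic so the cost grows like N log N instead of N squared.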

The next step is filtering, which removes noise from the input. As we know, we can filter data with the FFT by removing the high-frequency part. FNO does this slightly differently, but the essence remains the same.

As with ordinary FFT-based noise filtering, after the filtering step we undo the FFT with the inverse FFT. This restores the original input minus the part that was filtered out.
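The FFT, filter, inverse FFT sequence can be tried on a one-dimensional toy signal. This is a plain low-pass filter with a hard cutoff, not the learned filter of FNO; the frequencies and the 20 Hz cutoff are arbitrary choices for the illustration.

```python
import numpy as np

sample_rate = 200  # Hz
t = np.arange(sample_rate) / sample_rate  # 1 second of samples
clean = np.sin(2 * np.pi * 5 * t)                  # 5 Hz "signal"
noisy = clean + 0.5 * np.sin(2 * np.pi * 60 * t)   # plus a 60 Hz disturbance

# 1) physical space -> frequency space
spectrum = np.fft.rfft(noisy)
freqs = np.fft.rfftfreq(len(noisy), d=1 / sample_rate)

# 2) filtering: zero every frequency above the cutoff
spectrum[freqs > 20] = 0

# 3) frequency space -> physical space
filtered = np.fft.irfft(spectrum, n=len(noisy))

max_err = np.max(np.abs(filtered - clean))
print(max_err)
```

Because the disturbance sits entirely above the cutoff, the filtered signal matches the clean one up to floating-point error.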

Now, to make all of this concrete, here is the code for the FNO. I adapted this code from this GitHub repository by Guo. Here is the corresponding article of the code. I only adapt their FNO model code here; I do not go a step further and experiment with the “multi-grid” part. The code uses 2D convolutional layers because we want to handle an image dataset.

For FNO’s processing block, you can focus your attention on this class.

class SpectralConv2d_fast(nn.Module):
    def __init__(self, in_channels, out_channels, modes1, modes2):
        super(SpectralConv2d_fast, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = modes2

        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))

    def forward(self, x):
        batchsize = x.shape[0]
        x_ft = torch.fft.rfft2(x, norm='ortho')

        # Multiply relevant Fourier modes
        out_ft = torch.zeros(batchsize, self.out_channels, x.size(-2), x.size(-1) // 2 + 1, dtype=torch.cfloat, device=x.device)
        out_ft[:, :, :self.modes1, :self.modes2] = \
            compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2] = \
            compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)

        # Return to physical space
        x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1)), norm='ortho')
        return x

The FFT process is this 1 line.

x_ft = torch.fft.rfft2(x, norm='ortho')

Then the filtering part is here.

# Multiply relevant Fourier modes
out_ft = torch.zeros(batchsize, self.out_channels, x.size(-2), x.size(-1) // 2 + 1, dtype=torch.cfloat, device=x.device)
out_ft[:, :, :self.modes1, :self.modes2] = \
    compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
out_ft[:, :, -self.modes1:, :self.modes2] = \
    compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)

Now, this is the interesting part. As I said before, the filtering in FNO does not merely zero out the high frequencies; it also adjusts the importance of the low frequencies by multiplying them with the model’s weights. If, during training, a certain frequency turns out to have no impact on the prediction, its weight will become small or even 0.
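The multiplication with the weights is the einsum "bixy,ioxy->boxy" used by compl_mul2d in the full listing. The NumPy sketch below, with random stand-ins for the input modes and the learned weights, shows what it does: every kept mode (x, y) has its own small in_channels by out_channels matrix, and a mode whose weights are 0 contributes nothing.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, c_in, c_out, modes1, modes2 = 2, 3, 4, 5, 6


def random_complex(shape):
    return rng.standard_normal(shape) + 1j * rng.standard_normal(shape)


x_ft = random_complex((batch, c_in, modes1, modes2))     # kept Fourier modes of the input
weights = random_complex((c_in, c_out, modes1, modes2))  # stand-in for learned weights
weights[:, :, 0, 0] = 0  # pretend training found mode (0, 0) useless

# Same contraction as compl_mul2d
out_ft = np.einsum("bixy,ioxy->boxy", x_ft, weights)

# For one mode, the einsum is an ordinary matrix product over channels
kx, ky = 1, 2
manual = x_ft[:, :, kx, ky] @ weights[:, :, kx, ky]
mode_matches = np.allclose(out_ft[:, :, kx, ky], manual)
zeroed_mode_is_silent = np.allclose(out_ft[:, :, 0, 0], 0)
print(out_ft.shape, mode_matches, zeroed_mode_is_silent)
```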

As you can see, there are 2 multiplications here. The first one accesses out_ft[:, :, :self.modes1, :self.modes2], which means we only bother with frequencies below the “mode” we set earlier as a hyperparameter of the FNO. The second one accesses out_ft[:, :, -self.modes1:, :self.modes2]. As we know, the FFT of a real-valued signal is conjugate-symmetric: every positive frequency has a negative-frequency counterpart that is its complex conjugate, and the FFT output stores those negative frequencies at the end of the array. Also, the FFT can only resolve frequencies up to 1/2 of the sample rate of the data (the Nyquist frequency). To make it clear, here is an example of what happens if we apply the FFT to a signal that is the sum of a 10 Hz and a 30 Hz component.

import numpy as np
import matplotlib.pyplot as plt

def plot_fft_symmetry(signal, sample_rate, title="FFT Symmetry of a Real Signal"):
    N = len(signal)  # Number of samples
    T = 1.0 / sample_rate  # Sample spacing

    # Compute the FFT of the signal
    yf = np.fft.fft(signal)

    # Generate frequency bins
    xf = np.fft.fftfreq(N, T)

    # Sort by frequency so the negative frequencies are plotted on the left
    sorted_indices = np.argsort(xf)
    xf_sorted = xf[sorted_indices]
    yf_sorted = yf[sorted_indices]

    # --- Plotting ---
    plt.figure(figsize=(14, 10))
    plt.suptitle(title, fontsize=16)

    # Subplot 1: Original Signal
    plt.subplot(2, 1, 1)
    plt.plot(np.arange(N) * T, signal)
    plt.title('Original Real-Valued Signal')
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude')
    plt.grid(True)

    # Subplot 2: Magnitude Spectrum
    plt.subplot(2, 1, 2)
    plt.plot(xf_sorted, np.abs(yf_sorted))
    plt.title('Magnitude Spectrum: $|F(k)|$')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Magnitude')
    plt.grid(True)
    # Highlight symmetry: visually inspect |F[k]| vs |F[-k]|
    plt.axvline(0, color='gray', linestyle='--', linewidth=0.8)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Adjust layout to make space for suptitle
    plt.show()


# Composite signal: 10 Hz sine plus 30 Hz cosine, sampled at 200 Hz for 1 second
sample_rate_2 = 200  # Hz
duration_2 = 1  # seconds
frequency_a = 10  # Hz
frequency_b = 30  # Hz
t_2 = np.linspace(0, duration_2, int(sample_rate_2 * duration_2), endpoint=False)
signal_2 = 0.7 * np.sin(2 * np.pi * frequency_a * t_2) + 0.3 * np.cos(2 * np.pi * frequency_b * t_2)

plot_fft_symmetry(signal_2, sample_rate_2, "FFT Symmetry: Composite Signal")

And the result is here. As you can see, there are 4 peaks: at 10 and 30 Hz, and at their counterparts on the negative side, -10 and -30 Hz.
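The same observation can be checked numerically without a plot. This sketch rebuilds the composite signal, lists the frequencies whose magnitude stands out, and verifies the conjugate symmetry F[-k] = conj(F[k]); the threshold of 1.0 is an arbitrary choice that separates the peaks from floating-point noise.

```python
import numpy as np

sample_rate = 200  # Hz
t = np.arange(sample_rate) / sample_rate  # 1 second
signal = 0.7 * np.sin(2 * np.pi * 10 * t) + 0.3 * np.cos(2 * np.pi * 30 * t)

spectrum = np.fft.fft(signal)
freqs = np.fft.fftfreq(len(signal), d=1 / sample_rate)

# Frequencies that carry energy
peaks = np.sort(freqs[np.abs(spectrum) > 1.0])
print(peaks)

# Conjugate symmetry of a real signal: F[N - k] equals conj(F[k])
is_symmetric = np.allclose(spectrum[1:][::-1], np.conj(spectrum[1:]))
print(is_symmetric)
```

The negative-frequency half carries no new information, which is why FFT routines for real input can skip it.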

Now it is clear why we accessed :self.modes1 and -self.modes1: in the first and second multiplications: together they cover the lowest positive and negative frequencies along the height axis. But why is the second index always the same, :self.modes2? The reason is that rfft2, the 2D FFT for real-valued input, returns only the non-negative frequencies along the last dimension, so that dimension of the frequency-domain matrix shrinks to W // 2 + 1. Hence, we only need :self.modes2 in both multiplications. You can read more about multi-dimensional FFTs in my previous post.

Fourier N-dimension for dummy, I Guess…
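A quick shape check makes the halved last dimension visible. The sketch uses NumPy’s fft2 and rfft2, which follow the same layout as the torch.fft functions in the model; the 64 x 128 image size is an arbitrary example.

```python
import numpy as np

H, W = 64, 128
image = np.random.rand(H, W)

full = np.fft.fft2(image)    # all frequencies on both axes
half = np.fft.rfft2(image)   # last axis keeps only non-negative frequencies
print(full.shape, half.shape)

# The half spectrum is exactly the first W // 2 + 1 columns of the full one
same_columns = np.allclose(full[:, : W // 2 + 1], half)

# Along the first axis both signs remain: positives first, negatives at the end
row_freqs = np.fft.fftfreq(H) * H
print(row_freqs[:4], row_freqs[-4:])
```

So rows [:modes1] and [-modes1:] pick the low positive and low negative frequencies along the height axis, while columns [:modes2] are enough along the width axis.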

The third part transforms the filtered data back to physical space and returns it, using this line.

# Return to physical space
x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1)), norm='ortho')
return x

This core FNO building block is the basic block of SimpleBlock2d(), which handles the general workflow. Let’s focus on its forward code.

def forward(self, x):
    x = self.p(x)
    x = x.permute(0, 4, 1, 2, 3)
    x = x.view(x.shape[0], -1, x.shape[3], x.shape[4])

    x1 = self.norm(self.conv0(self.norm(x)))
    x1 = self.mlp0(x1)
    x2 = self.w0(x)
    x = x1 + x2 + x
    x = F.gelu(x)

    x1 = self.norm(self.conv1(self.norm(x)))
    x1 = self.mlp1(x1)
    x2 = self.w1(x)
    x = x1 + x2 + x
    x = F.gelu(x)

    x1 = self.norm(self.conv2(self.norm(x)))
    x1 = self.mlp2(x1)
    x2 = self.w2(x)
    x = x1 + x2 + x
    x = F.gelu(x)

    x1 = self.norm(self.conv3(self.norm(x)))
    x1 = self.mlp3(x1)
    x2 = self.w3(x)
    x = x1 + x2 + x
    x = F.gelu(x)

    x = self.q(x)
    x = x.permute(0, 2, 3, 1)

    return x

The first step is the feature extractor, which is the first few lines in forward(). As I said before, its job is to match the shape of the input to what the FNO blocks expect. Its last step reshapes the batch so that the T and C dimensions are combined.

x = self.p(x)
x = x.permute(0, 4, 1, 2, 3)
x = x.view(x.shape[0], -1, x.shape[3], x.shape[4])

After that, everything up to x = self.q(x) is the FNO block processing. As you can see, it consists of the same few lines repeated.

x1 = self.norm(self.conv1(self.norm(x)))
x1 = self.mlp1(x1)
x2 = self.w1(x)
x = x1 + x2 + x
x = F.gelu(x)

The computation flow is like this. We adopt a skip connection here because the model is quite deep. Norm here is a batch normalisation step, and GELU is an activation function similar to ReLU but smoother.
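To see what “similar to ReLU but smoother” means, the sketch below evaluates both functions at a few points. It uses the exact GELU definition x * Phi(x), where Phi is the standard normal CDF; the helper names are introduced only for this illustration.

```python
import math


def relu(x):
    return max(0.0, x)


def gelu(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


for value in (-3.0, -1.0, -0.1, 0.0, 0.1, 1.0, 3.0):
    print(f"x={value:+.1f}  relu={relu(value):.4f}  gelu={gelu(value):+.4f}")
```

ReLU cuts every negative input to exactly 0, while GELU lets small negative values through in a damped form and has no kink at 0.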

This process runs inside Net2d(), the main class of the model, which runs 1 instance of SimpleBlock2d(). In my opinion, if you want more processing layers for more complex cases, it is better to add an FNO processing block inside SimpleBlock2d() than to add another instance of SimpleBlock2d() to Net2d().

Turbulence Prediction Case

Now all the lines in the FNO model are clear. Let’s see how this model performs on a simple case. The test case I use here is turbulence prediction. I got the dataset from this repo; see its “Downloading our Data and Model” section. I downloaded the data using the rsync command.

rsync -P rsync://m1734798.001@dataserv.ub.tum.de/m1734798.001/128_tra_small* ./data

From there, you will get several turbulence parameters, such as density, pressure, and velocity. The dataset has 1000 timesteps in total.

In this test case, I try to predict density using only a sequence of the 10 previous density fields. Thus, I deleted the other parameters.

To load the dataset, I’ve created this dataset class.

# setting
dataset_dir = 'dataset/density_training'
# Sequence lengths
SEQ_LEN_FEATURE = 10  # Number of past time-nodes as input for prediction
SEQ_LEN_ROLLOUT = 25  # Number of future timesteps to predict in an unrolling manner

class DensityTurbulenceDataset(Dataset):
    def __init__(self, dataset_dir, seq_len_total):
        self.file_list = sorted(glob(os.path.join(dataset_dir, "*.npz")))
        self.seq_len_total = seq_len_total  # SEQ_LEN_FEATURE + SEQ_LEN_ROLLOUT

        # Load every file once so __getitem__ never touches the disk
        self.data_cache = []
        print(f"Loading dataset from {dataset_dir} into memory...")
        for f_path in tqdm(self.file_list, desc=f"Loading {os.path.basename(dataset_dir)}"):
            self.data_cache.append(np.load(f_path)['arr_0'])
        print(f"Dataset from {dataset_dir} loaded. Total files: {len(self.data_cache)}")

    def __len__(self):
        # Number of sliding windows of length seq_len_total
        return len(self.file_list) - self.seq_len_total + 1

    def __getitem__(self, idx):
        sequence_data = np.stack([self.data_cache[idx + i] for i in range(self.seq_len_total)])
        # First SEQ_LEN_FEATURE frames are the input
        inputs_full_sequence = sequence_data[:SEQ_LEN_FEATURE, :, :, :]

        # Remaining frames (channel 0, the density) are the rollout targets
        targets_density_sequence = sequence_data[SEQ_LEN_FEATURE:, 0, :, :]

        return torch.from_numpy(inputs_full_sequence).float(), \
            torch.from_numpy(targets_density_sequence).float()
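The sliding-window logic of this class can be followed on in-memory stand-ins. In the sketch below, 40 tiny constant frames replace the .npz files (frame i is filled with the value i), so it is easy to see which frames end up as input and which as target; the frame count and frame size are made up.

```python
import numpy as np

SEQ_LEN_FEATURE = 10
SEQ_LEN_ROLLOUT = 25
seq_len_total = SEQ_LEN_FEATURE + SEQ_LEN_ROLLOUT

# Stand-in for the loaded files: 40 frames of shape (1, 4, 8)
frames = [np.full((1, 4, 8), i, dtype=np.float32) for i in range(40)]

# Same formula as __len__: one sample per window start
n_samples = len(frames) - seq_len_total + 1

# Same slicing as __getitem__ for idx = 2
idx = 2
sequence = np.stack(frames[idx: idx + seq_len_total])
inputs = sequence[:SEQ_LEN_FEATURE]        # frames 2..11
targets = sequence[SEQ_LEN_FEATURE:, 0]    # frames 12..36, channel 0

print(n_samples, inputs.shape, targets.shape)
print(inputs[0, 0, 0, 0], targets[0, 0, 0], targets[-1, 0, 0])
```

With 1000 real timesteps and the same window length of 35, the formula gives 966 training samples.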

Here is the training code of the U-Net (with attention).

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from glob import glob
import os
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange
from functools import partial

# ---- SETTINGS ----
dataset_dir = 'dataset/density_training'
total_epoch = 300
seq_len_feature = 9  # number of input frames; must equal channels=9 passed to Unet below
seq_len_rollout = 25

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)

def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

class ConvNextBlock(nn.Module):

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            if time_emb_dim is not None else None
        )

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)

        if self.mlp is not None and time_emb is not None:
            assert time_emb is not None, "time embedding must be passed in"
            condition = self.mlp(time_emb)
            h = h + rearrange(condition, "b c -> b c 1 1")

        h = self.net(h)
        return h + self.res_conv(x)

class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q * self.scale

        sim = torch.einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = torch.einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1),
                                    nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

class Unet(nn.Module):
    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        with_time_emb=True,
        resnet_block_groups=8,
        use_convnext=True,
        convnext_mult=2,
    ):
        super().__init__()

        # determine dimensions
        self.channels = channels

        init_dim = init_dim if init_dim is not None else dim // 3 * 2
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            raise NotImplementedError()

        # time embeddings
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )

        else:
            time_dim = None
            self.time_mlp = None
            self.cond_mlp = None
            self.sim_mlp = None

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        out_dim = out_dim if out_dim is not None else channels
        self.final_conv = nn.Sequential(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def forward(self, x, time):
        x = self.init_conv(x)

        t = self.time_mlp(time) if self.time_mlp is not None else None

        h = []

        # downsample
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        # bottleneck
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # upsample
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        return self.final_conv(x)


# ---- DATASET ----
class DensitySequenceDataset(Dataset):
    def __init__(self, dataset_dir, seqlen=10):
        self.file_list = sorted(glob(os.path.join(dataset_dir, "*.npz")))
        self.seqlen = seqlen

    def __len__(self):
        return len(self.file_list) - self.seqlen

    def __getitem__(self, idx):
        seqdata2d = []
        for iter_idx in range(self.seqlen):
            single_data = np.load(self.file_list[idx+iter_idx])['arr_0']  # expected shape: (1, H, W), density only
            seqdata2d.append(single_data)
        seqdata2d = np.array(seqdata2d)
        data = torch.from_numpy(seqdata2d).float()
        return data

dataset = DensitySequenceDataset(dataset_dir, seqlen=seq_len_feature+seq_len_rollout)
batch_size = 4
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Normalisation statistics, cast to float32 so the batch keeps the model's dtype
mean_std = np.load('density_stats.npz')
mean = mean_std['mean']
std = mean_std['std']

mean_device = torch.from_numpy(mean).float().to(device)
std_device = torch.from_numpy(std).float().to(device)

# channels must equal seq_len_feature (the number of input frames)
model = Unet(dim=32, channels=9, out_dim=1, with_time_emb=False).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
print(model)

# Resume from the best checkpoint if one exists
checkpoint_path = 'best_advunet_turb.pth'
best_loss = float('inf')
last_epoch = -1
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    last_epoch = checkpoint['epoch']
    best_loss = checkpoint['loss'] / len(loader)

model.train()
for epoch in range(total_epoch):
    if epoch <= last_epoch:
        continue
    total_loss = 0
    loop = tqdm(loader, desc=f"Epoch {epoch+1}/{total_epoch}")
    for batch in loop:
        batch = batch.to(device)
        batch = (batch - mean_device) / std_device

        inputs = batch[:, :seq_len_feature].squeeze(2)
        target = batch[:, seq_len_feature:].squeeze(2)
        # Use the actual batch size: the last batch can be smaller than batch_size
        time_embd = torch.linspace(0, 1, steps=seq_len_feature+seq_len_rollout, device=batch.device).unsqueeze(0).repeat(batch.shape[0], 1)

        optimizer.zero_grad()
        pred_frames = []
        gt_frames = []
        for iter_r in range(seq_len_rollout):
            preds = model(inputs, time_embd[:, iter_r:(seq_len_feature+iter_r)])
            pred_frames.append(preds)
            gt_frames.append(target[:, iter_r])

            new_inputs = torch.cat([
                inputs[:, 1:].clone(),  # drop oldest
                preds.detach()          # add newest
            ], dim=1)
            inputs = new_inputs  # overwrite safely
        # Squeeze only the channel axis so a batch of size 1 keeps its batch axis
        pred_frames = torch.stack(pred_frames, dim=1).squeeze(2)
        gt_frames = torch.stack(gt_frames, dim=1)
        loss = loss_fn(pred_frames, gt_frames)
        try:
            torch.autograd.set_detect_anomaly(True)
            loss.backward()
        except RuntimeError as e:
            print("BACKWARD ERROR:", e)
            raise e
        optimizer.step()

        total_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    avg_loss = total_loss / len(loader)
    print(f"Epoch {epoch+1}, Avg Loss: {avg_loss:.6f}")

    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': total_loss
        }, checkpoint_path)
    #break

# ---- FINAL SAVE ----
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': total_loss
}, 'final_advunet_turb.pth')

After training, you will get a .pth model file of about 173 MB. The file size reflects the number of weights and biases in the model, plus the Adam optimizer state that the checkpoint stores alongside them. A bigger model has more capacity to learn, but it also puts a heavier load on your VRAM (or RAM if you don’t use a GPU).
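As a rough back-of-the-envelope check of how parameter count relates to checkpoint size, here is a small sketch. It assumes float32 values (4 bytes each) and that Adam keeps two extra values per parameter, and it ignores buffers and file-format overhead; checkpoint_size_mb is a hypothetical helper, not part of the training script.

```python
def checkpoint_size_mb(n_params, bytes_per_value=4, with_adam_state=True):
    """Approximate checkpoint size in MB (1 MB = 1e6 bytes).

    With Adam, each parameter comes with two running averages,
    so the checkpoint holds about 3 values per parameter.
    """
    copies = 3 if with_adam_state else 1
    return n_params * bytes_per_value * copies / 1e6


# Hypothetical model with 1 million parameters
size_with_adam = checkpoint_size_mb(1_000_000)
size_weights_only = checkpoint_size_mb(1_000_000, with_adam_state=False)
print(size_with_adam, size_weights_only)
```

Saving only model.state_dict() for inference would therefore give a noticeably smaller file than a checkpoint that also contains the optimizer state.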

Here is the code to produce the result. Running it performs a rollout prediction for the next 100 timesteps. In the end, you will get a stack of prediction images in your target directory. For convenience, I combined all of the images into a GIF.

import torch
import numpy as np
from glob import glob
import os
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange
from functools import partial
import matplotlib.pyplot as plt

# ---- SETTINGS ----
dataset_dir = 'dataset/density_training'
total_epoch = 300
seq_len_feature = 9  # number of input frames; matches channels=9 and the 9 files loaded below
#seq_len_rollout = 10
dataset_testdir = 'dataset/density_test'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)

def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

class ConvNextBlock(nn.Module):

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            if time_emb_dim is not None else None
        )

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)

        if self.mlp is not None and time_emb is not None:
            assert time_emb is not None, "time embedding must be passed in"
            condition = self.mlp(time_emb)
            h = h + rearrange(condition, "b c -> b c 1 1")

        h = self.net(h)
        return h + self.res_conv(x)

class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q * self.scale

        sim = torch.einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = torch.einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1),
                                    nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

class Unet(nn.Module):
    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        with_time_emb=True,
        resnet_block_groups=8,
        use_convnext=True,
        convnext_mult=2,
    ):
        super().__init__()

        # determine dimensions
        self.channels = channels

        init_dim = init_dim if init_dim is not None else dim // 3 * 2
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            raise NotImplementedError()

        # time embeddings
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )

        else:
            time_dim = None
            self.time_mlp = None
            self.cond_mlp = None
            self.sim_mlp = None

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        out_dim = out_dim if out_dim is not None else channels
        self.final_conv = nn.Sequential(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def forward(self, x, time):
        x = self.init_conv(x)

        t = self.time_mlp(time) if self.time_mlp is not None else None

        h = []

        # downsample
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        # bottleneck
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # upsample
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        return self.final_conv(x)



model = Unet(dim=32, channels=9, out_dim=1, with_time_emb=False).to(device)
model.load_state_dict(torch.load("best_advunet_turb.pth", map_location=device)["model_state_dict"])
model.eval()

mean_std = np.load('density_stats.npz')
mean = mean_std['mean']
std = mean_std['std']

mean_device = torch.from_numpy(mean).float().to(device)
std_device = torch.from_numpy(std).float().to(device)

# The last 9 training frames are the initial input window
file_list = sorted(glob(os.path.join(dataset_dir, "*.npz")))[-9:]
print(file_list)

seqdata2d = []
for iter_f in range(len(file_list)):
    single_data = np.load(file_list[iter_f])['arr_0']  # expected shape: (1, H, W), density only
    seqdata2d.append(single_data)
seqdata2d = np.array(seqdata2d)
feature = torch.from_numpy(seqdata2d).float().to(device)
feature = (feature - mean_device) / std_device
feature = feature.squeeze(1).unsqueeze(0)  # (1, 9, H, W)

file_list_test = sorted(glob(os.path.join(dataset_testdir, "*.npz")))
seqtruth = []
for iter_f in range(len(file_list_test)):
    single_data = np.load(file_list_test[iter_f])['arr_0']
    seqtruth.append(single_data)
seqtruth = np.array(seqtruth)


seqpred = []
with torch.no_grad():
    time_embd = torch.linspace(0, 1, steps=100, device=device).unsqueeze(0).repeat(1, 1)
    for iter_step in range(100):
        print('check feature shape', feature.size())
        pred = model(feature, time_embd[:, iter_step:(seq_len_feature+iter_step)])
        pred_real = pred * std_device + mean_device
        pred_real = pred_real.detach().squeeze().cpu().numpy()
        seqpred.append(pred_real)

        # Slide the input window along the frame axis (dim=1), as in training:
        # drop the oldest frame and append the newest prediction.
        # Indexing feature[:-1] / feature[-1] would act on the batch axis instead.
        feature = torch.cat([feature[:, 1:], pred], dim=1)

seqpred = np.array(seqpred)

#assert len(seqpred) == len(seqtruth)
os.makedirs("res_turb_advunet", exist_ok=True)  # output directory for the images
vmin = np.min(seqtruth)
vmax = np.max(seqtruth)
# Only plot timesteps for which a ground-truth frame exists
for iter_p in range(min(len(seqpred), len(seqtruth))):
    fig, axes = plt.subplots(1, 2, figsize=(10, 6))

    im0 = axes[0].imshow(seqtruth[iter_p,0], cmap='viridis', vmin=vmin, vmax=vmax)
    axes[0].set_title(f"Ground Truth - t={iter_p}")
    axes[0].axis('off')

    im1 = axes[1].imshow(seqpred[iter_p], cmap='viridis', vmin=vmin, vmax=vmax)
    axes[1].set_title(f"Prediction - t={iter_p}")
    axes[1].axis('off')

    #fig.colorbar(im0, ax=axes.ravel().tolist(), shrink=0.6, label="Density")
    fig.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)
    plt.suptitle(f"Timestep {iter_p+1} - Ground Truth vs. Prediction (U-NET)", fontsize=14)
    #plt.tight_layout()
    plt.tight_layout(rect=[0, 0, 1, 0.95])  # reserve space for suptitle

    fig.savefig(f"res_turb_advunet/comparison_t{iter_p:03d}.png", dpi=150)
    plt.close(fig)

This is the result

As you can see, the movement of those “bubbles” is somehow too fast, but the model captures the movement patterns in the dataset.
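Besides looking at the images, one way to put a number on how a rollout drifts from the ground truth is the root-mean-square error per timestep. The sketch below is a minimal NumPy version run on synthetic arrays; rmse_per_timestep is a hypothetical helper, and the numbers it prints are not results from the model above.

```python
import numpy as np


def rmse_per_timestep(pred, truth):
    """RMSE over the spatial axes for arrays of shape (T, H, W)."""
    pred = np.asarray(pred, dtype=np.float64)
    truth = np.asarray(truth, dtype=np.float64)
    if pred.shape != truth.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {truth.shape}")
    return np.sqrt(np.mean((pred - truth) ** 2, axis=(1, 2)))


# Synthetic example: the "prediction" is off by k at timestep k
truth = np.zeros((3, 4, 5))
pred = np.stack([np.full((4, 5), float(k)) for k in range(3)])
errors = rmse_per_timestep(pred, truth)
print(errors)
```

With the arrays from the script above, one could pass seqpred and seqtruth[:len(seqpred), 0], assuming both cover the same timesteps; an error that grows with the timestep index is the usual signature of accumulated rollout error.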

Here is the code for the FNO model.

import torch
from torch.utils.data import Dataset, DataLoader
from neuralop import LpLoss, H1Loss
import numpy as np
from glob import glob
import os
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F

from functools import reduce

import operator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#setting
dataset_dir = 'dataset/density_training'
# Sequence lengths
SEQ_LEN_FEATURE = 10 # Number of past time-nodes as input for prediction
SEQ_LEN_ROLLOUT = 25 # Number of future timesteps to predict in an unrolling manner
C_DATA_PER_TIMESTEP = 10
BATCH_SIZE = 8
TOTAL_EPOCHS = 300
WIDTH_IMG = 128
HEIGHT_IMG = 64

torch.manual_seed(0)
np.random.seed(0)

def compl_mul2d(input, weights):
    # (batch, in_channel, x, y), (in_channel, out_channel, x, y) -> (batch, out_channel, x, y)
    return torch.einsum("bixy,ioxy->boxy", input, weights)


class DensityTurbulenceDataset(Dataset):
    def __init__(self, dataset_dir, seq_len_total):
        self.file_list = sorted(glob(os.path.join(dataset_dir, "*.npz")))
        self.seq_len_total = seq_len_total # SEQ_LEN_FEATURE + SEQ_LEN_ROLLOUT

        self.data_cache = []
        print(f"Loading dataset from {dataset_dir} into memory...")
        for f_path in tqdm(self.file_list, desc=f"Loading {os.path.basename(dataset_dir)}"):
            self.data_cache.append(np.load(f_path)['arr_0'])
        print(f"Dataset from {dataset_dir} loaded. Total files: {len(self.data_cache)}")

    def __len__(self):
        return len(self.file_list) - self.seq_len_total + 1

    def __getitem__(self, idx):
        sequence_data = np.stack([self.data_cache[idx + i] for i in range(self.seq_len_total)])
        inputs_full_sequence = sequence_data[:SEQ_LEN_FEATURE, :, :, :]

        targets_density_sequence = sequence_data[SEQ_LEN_FEATURE:, 0, :, :]

        return torch.from_numpy(inputs_full_sequence).float(), \
               torch.from_numpy(targets_density_sequence).float()

################################################################
# fourier layer
################################################################

class SpectralConv2d_fast(nn.Module):
    def __init__(self, in_channels, out_channels, modes1, modes2):
        super(SpectralConv2d_fast, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1 # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = modes2

        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))

    def forward(self, x):
        batchsize = x.shape[0]
        x_ft = torch.fft.rfft2(x, norm='ortho')

        # Multiply relevant Fourier modes
        out_ft = torch.zeros(batchsize, self.out_channels, x.size(-2), x.size(-1) // 2 + 1, dtype=torch.cfloat, device=x.device)
        out_ft[:, :, :self.modes1, :self.modes2] = \
            compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2] = \
            compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)

        # Return to physical space
        x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1)), norm='ortho')
        return x

class MLP(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels):
        super(MLP, self).__init__()
        self.mlp1 = nn.Conv2d(in_channels, mid_channels, 1)
        self.mlp2 = nn.Conv2d(mid_channels, out_channels, 1)

    def forward(self, x):
        x = self.mlp1(x)
        x = F.gelu(x)
        x = self.mlp2(x)
        return x

class SimpleBlock2d(nn.Module):
    def __init__(self, modes1, modes2, width):
        super(SimpleBlock2d, self).__init__()

        self.modes1 = modes1
        self.modes2 = modes2
        self.width = width
        self.padding = 8
        # Lifting layer: 12 input features = SEQ_LEN_FEATURE (10) frames + 2 coordinate grids (x, y)
        self.p = nn.Linear(12, self.width)

        self.conv0 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)
        self.conv1 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)
        self.conv2 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)
        self.conv3 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)
        self.mlp0 = MLP(self.width, self.width, self.width)
        self.mlp1 = MLP(self.width, self.width, self.width)
        self.mlp2 = MLP(self.width, self.width, self.width)
        self.mlp3 = MLP(self.width, self.width, self.width)
        self.w0 = nn.Conv2d(self.width, self.width, 1)
        self.w1 = nn.Conv2d(self.width, self.width, 1)
        self.w2 = nn.Conv2d(self.width, self.width, 1)
        self.w3 = nn.Conv2d(self.width, self.width, 1)
        self.norm = nn.InstanceNorm2d(self.width)

        self.q = MLP(self.width, 1, self.width * 4) # output channel is 1: u(x, y)

    def forward(self, x):
        # x arrives as (B, 1, H, W, 12); lift the last dimension to `width` features
        x = self.p(x)
        x = x.permute(0, 4, 1, 2, 3)
        x = x.view(x.shape[0], -1, x.shape[3], x.shape[4])  # (B, width, H, W)

        # Fourier block 0: spectral path + pointwise path + residual
        x1 = self.norm(self.conv0(self.norm(x)))
        x1 = self.mlp0(x1)
        x2 = self.w0(x)
        x = x1 + x2 + x
        x = F.gelu(x)

        # Fourier block 1
        x1 = self.norm(self.conv1(self.norm(x)))
        x1 = self.mlp1(x1)
        x2 = self.w1(x)
        x = x1 + x2 + x
        x = F.gelu(x)

        # Fourier block 2
        x1 = self.norm(self.conv2(self.norm(x)))
        x1 = self.mlp2(x1)
        x2 = self.w2(x)
        x = x1 + x2 + x
        x = F.gelu(x)

        # Fourier block 3
        x1 = self.norm(self.conv3(self.norm(x)))
        x1 = self.mlp3(x1)
        x2 = self.w3(x)
        x = x1 + x2 + x
        x = F.gelu(x)

        # Project back to a single output channel: (B, 1, H, W) -> (B, H, W, 1)
        x = self.q(x)
        x = x.permute(0, 2, 3, 1)

        return x


class Net2d(nn.Module):
    """
    A wrapper module
    """

    def __init__(self, modes, width):
        super(Net2d, self).__init__()

        self.conv1 = SimpleBlock2d(modes, modes, width)

    def forward(self, x):
        x = self.conv1(x)
        return x

    def count_params(self):
        c = 0
        for p in self.parameters():
            c += reduce(operator.mul, list(p.size()))

        return c

def create_embd_dim(batch_dim):
    # Normalized coordinate grids in [0, 1], one value per grid point
    x = torch.from_numpy(np.linspace(0, 1, HEIGHT_IMG).reshape(1, HEIGHT_IMG).repeat(WIDTH_IMG, axis=0))
    y = torch.from_numpy(np.linspace(0, 1, WIDTH_IMG).reshape(WIDTH_IMG, 1).repeat(HEIGHT_IMG, axis=1))

    x = x.view(1, 1, 1, WIDTH_IMG, HEIGHT_IMG).to(device).float()
    y = y.view(1, 1, 1, WIDTH_IMG, HEIGHT_IMG).to(device).float()

    x = x.repeat(batch_dim, 1, 1, 1, 1) # Shape: (batch_dim, 1, 1, WIDTH_IMG, HEIGHT_IMG)
    y = y.repeat(batch_dim, 1, 1, 1, 1) # Shape: (batch_dim, 1, 1, WIDTH_IMG, HEIGHT_IMG)
    return x, y

dataset = DensityTurbulenceDataset(
    dataset_dir,
    seq_len_total=SEQ_LEN_FEATURE + SEQ_LEN_ROLLOUT
)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)


# Load pre-calculated mean and std for normalization
mean_std = np.load('density_stats.npz')
mean = mean_std['mean']
std = mean_std['std']

model = Net2d(16, 20).to(device)

# Convert mean/std to tensors and move to device
mean_device = torch.from_numpy(mean).float().to(device)
std_device = torch.from_numpy(std).float().to(device)


learning_rate = 0.001
scheduler_step = 100
scheduler_gamma = 0.5

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)
#l2loss = LpLoss(size_average=False)
l2loss = LpLoss(d=2, p=2, reduce_dims=(0,1))
h1_loss = H1Loss(d=2, reduce_dims=(0,1))

# ---- CHECKPOINT LOADING ----
checkpoint_path = 'best_mgfnol2losshlossneuralop_turbulence_unroll.pth'
best_loss = float('inf')
last_epoch = -1
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    last_epoch = checkpoint['epoch']
    best_loss = checkpoint['loss']
    print(f"Loaded checkpoint from epoch {last_epoch}, best loss: {best_loss:.6f}")

# ---- TRAINING LOOP ----
model.train()
for epoch in range(TOTAL_EPOCHS):
    if epoch <= last_epoch:
        continue
    total_loss = 0
    loop = tqdm(loader, desc=f"Epoch {epoch+1}/{TOTAL_EPOCHS}")

    for inputs_seq_raw, targets_seq_raw in loop:
        inputs_seq_raw = inputs_seq_raw.to(device)
        targets_seq_raw = targets_seq_raw.to(device)
        inputs_seq_norm = (inputs_seq_raw - mean_device) / std_device
        targets_seq_norm = (targets_seq_raw - mean_device) / std_device

        optimizer.zero_grad()

        # Start the unrolling process
        current_input_sequence = inputs_seq_norm # (B, SEQ_LEN_FEATURE, 1, H, W)

        predicted_density_norm = []
        ground_truth_density_norm = []

        for r_step in range(SEQ_LEN_ROLLOUT):
            x, y = create_embd_dim(len(current_input_sequence))
            used_input_sequence = torch.cat((current_input_sequence, x, y), dim=1)

            used_input_sequence = used_input_sequence.permute(0, 2, 3, 4, 1)

            pred_norm = model(used_input_sequence) # (B, H, W, 1)
            # Drop only the trailing channel axis so a batch of size 1 keeps its batch dimension
            pred_norm = pred_norm.squeeze(-1).unsqueeze(1) # (B, 1, H, W)

            predicted_density_norm.append(pred_norm)

            current_gt_frame_norm = targets_seq_norm[:, r_step:r_step+1, :, :] # (B, 1, H, W)
            ground_truth_density_norm.append(current_gt_frame_norm)

            prepared_next_input = pred_norm.unsqueeze(1) # (B, 1, 1, H, W)

            current_input_sequence = torch.cat([
                current_input_sequence[:, 1:, :, :], # Drop oldest (1st) frame
                prepared_next_input # Add the new predicted normalized frame
            ], dim=1) # Concatenate along the 'sequence' dimension (dim 1)

        predicted_density_norm = torch.cat(predicted_density_norm, dim=1)
        ground_truth_density_norm = torch.cat(ground_truth_density_norm, dim=1)

        loss = l2loss(predicted_density_norm, ground_truth_density_norm) + 0.2 * h1_loss(predicted_density_norm, ground_truth_density_norm)

        try:
            loss.backward()
        except RuntimeError as e:
            print("BACKWARD ERROR:", e)
            raise e # Re-raise to stop training and inspect

        optimizer.step()

        total_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    scheduler.step()

    avg_loss = total_loss / len(loader)
    print(f"Epoch {epoch+1}, Avg Loss: {avg_loss:.6f}")

    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': best_loss
        }, checkpoint_path)
        print(f"Saved new best model at epoch {epoch+1} with loss {best_loss:.6f}")

# ---- FINAL SAVE ----
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': total_loss / len(loader) # Save final average loss
}, 'final_mgfnol2losshlossneuralop_turbulence_unroll.pth')

print("Training complete.")

You will get a .pth model file of about 20 MB, which is significantly smaller than the U-Net model. Here is the test code.
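As a sanity check on that file size, the parameter count of `Net2d(16, 20)` can be worked out by hand from the layer definitions above. The sketch below is a back-of-the-envelope estimate under stated assumptions: float32 storage, each complex weight counted as two floats, `nn.InstanceNorm2d` left at its default of no learnable parameters, and a checkpoint that also stores Adam's two moment buffers per parameter (the training script saves `optimizer_state_dict`). The helper names are made up for this illustration.

```python
def spectral_conv_params(in_ch, out_ch, modes1, modes2):
    """Real-valued floats in one SpectralConv2d_fast: 2 complex weight tensors, 2 floats per complex."""
    return 2 * in_ch * out_ch * modes1 * modes2 * 2

def conv1x1_params(in_ch, out_ch):
    """Weights plus biases of a 1x1 Conv2d (same count as a Linear layer)."""
    return in_ch * out_ch + out_ch

width, modes, in_features, n_blocks = 20, 16, 12, 4

spectral = n_blocks * spectral_conv_params(width, width, modes, modes)
lifting = conv1x1_params(in_features, width)                 # self.p = nn.Linear(12, width)
mlps = n_blocks * 2 * conv1x1_params(width, width)           # mlp0..mlp3, two 1x1 convs each
skips = n_blocks * conv1x1_params(width, width)              # w0..w3
projection = conv1x1_params(width, 4 * width) + conv1x1_params(4 * width, 1)  # self.q

total = spectral + lifting + mlps + skips + projection
model_mb = total * 4 / 1e6          # float32 = 4 bytes
checkpoint_mb = 3 * model_mb        # weights + Adam exp_avg + Adam exp_avg_sq

print(f"spectral share: {spectral / total:.1%}")
print(f"weights only: {model_mb:.2f} MB, with Adam state: {checkpoint_mb:.2f} MB")
```

Under these assumptions almost all parameters sit in the spectral weights, and the estimate with optimizer state lands close to the 20 MB reported above. Saving only `model_state_dict` would give a file roughly a third of that size.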

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from glob import glob
import os
import matplotlib.pyplot as plt

from functools import reduce
from functools import partial

import operator

# ---- CONFIGURATION SETTINGS ----
dataset_dir = 'dataset/density_test'
dataset_test_dir = 'dataset/density_test'
C_DATA_PER_TIMESTEP = 9 # Number of input channels in your .npz file per timestep (e.g., density, vx, vy, etc.)

SEQ_LEN_FEATURE = 10
SEQ_LEN_ROLLOUT = 25

WIDTH_IMG = 128
HEIGHT_IMG = 64

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

################################################################
# fourier layer
################################################################

class SpectralConv2d_fast(nn.Module):
    def __init__(self, in_channels, out_channels, modes1, modes2):
        super(SpectralConv2d_fast, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1 # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = modes2

        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))

    def forward(self, x):
        batchsize = x.shape[0]
        # Compute Fourier coefficients up to factor of e^(- something constant)
        x_ft = torch.fft.rfft2(x, norm='ortho')

        # Multiply relevant Fourier modes
        out_ft = torch.zeros(batchsize, self.out_channels, x.size(-2), x.size(-1) // 2 + 1, dtype=torch.cfloat, device=x.device)
        out_ft[:, :, :self.modes1, :self.modes2] = \
            compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2] = \
            compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)

        # Return to physical space
        x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1)), norm='ortho')
        return x

class MLP(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels):
        super(MLP, self).__init__()
        self.mlp1 = nn.Conv2d(in_channels, mid_channels, 1)
        self.mlp2 = nn.Conv2d(mid_channels, out_channels, 1)

    def forward(self, x):
        x = self.mlp1(x)
        x = F.gelu(x)
        x = self.mlp2(x)
        return x

class SimpleBlock2d(nn.Module):
    def __init__(self, modes1, modes2, width):
        super(SimpleBlock2d, self).__init__()
        self.modes1 = modes1
        self.modes2 = modes2
        self.width = width
        self.padding = 8
        # Lifting layer: 12 input features = SEQ_LEN_FEATURE (10) frames + 2 coordinate grids (x, y)
        self.p = nn.Linear(12, self.width)

        self.conv0 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)
        self.conv1 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)
        self.conv2 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)
        self.conv3 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)
        self.mlp0 = MLP(self.width, self.width, self.width)
        self.mlp1 = MLP(self.width, self.width, self.width)
        self.mlp2 = MLP(self.width, self.width, self.width)
        self.mlp3 = MLP(self.width, self.width, self.width)
        self.w0 = nn.Conv2d(self.width, self.width, 1)
        self.w1 = nn.Conv2d(self.width, self.width, 1)
        self.w2 = nn.Conv2d(self.width, self.width, 1)
        self.w3 = nn.Conv2d(self.width, self.width, 1)
        self.norm = nn.InstanceNorm2d(self.width)

        self.q = MLP(self.width, 1, self.width * 4) # output channel is 1: u(x, y)

    def forward(self, x):
        # x arrives as (B, 1, H, W, 12); lift the last dimension to `width` features
        x = self.p(x)
        x = x.permute(0, 4, 1, 2, 3)
        x = x.view(x.shape[0], -1, x.shape[3], x.shape[4])  # (B, width, H, W)

        # Fourier block 0: spectral path + pointwise path + residual
        x1 = self.norm(self.conv0(self.norm(x)))
        x1 = self.mlp0(x1)
        x2 = self.w0(x)
        x = x1 + x2 + x
        x = F.gelu(x)

        # Fourier block 1
        x1 = self.norm(self.conv1(self.norm(x)))
        x1 = self.mlp1(x1)
        x2 = self.w1(x)
        x = x1 + x2 + x
        x = F.gelu(x)

        # Fourier block 2
        x1 = self.norm(self.conv2(self.norm(x)))
        x1 = self.mlp2(x1)
        x2 = self.w2(x)
        x = x1 + x2 + x
        x = F.gelu(x)

        # Fourier block 3
        x1 = self.norm(self.conv3(self.norm(x)))
        x1 = self.mlp3(x1)
        x2 = self.w3(x)
        x = x1 + x2 + x
        x = F.gelu(x)

        # Project back to a single output channel: (B, 1, H, W) -> (B, H, W, 1)
        x = self.q(x)
        x = x.permute(0, 2, 3, 1)

        return x


class Net2d(nn.Module):
    """
    A wrapper module
    """

    def __init__(self, modes, width):
        super(Net2d, self).__init__()

        self.conv1 = SimpleBlock2d(modes, modes, width)

    def forward(self, x):
        x = self.conv1(x)
        return x

    def count_params(self):
        c = 0
        for p in self.parameters():
            c += reduce(operator.mul, list(p.size()))

        return c

def create_embd_dim(batch_dim):
    # Normalized coordinate grids in [0, 1], one value per grid point
    x = torch.from_numpy(np.linspace(0, 1, HEIGHT_IMG).reshape(1, HEIGHT_IMG).repeat(WIDTH_IMG, axis=0))
    y = torch.from_numpy(np.linspace(0, 1, WIDTH_IMG).reshape(WIDTH_IMG, 1).repeat(HEIGHT_IMG, axis=1))

    x = x.view(1, 1, 1, WIDTH_IMG, HEIGHT_IMG).to(device).float()
    y = y.view(1, 1, 1, WIDTH_IMG, HEIGHT_IMG).to(device).float()

    x = x.repeat(batch_dim, 1, 1, 1, 1) # Shape: (batch_dim, 1, 1, WIDTH_IMG, HEIGHT_IMG)
    y = y.repeat(batch_dim, 1, 1, 1, 1) # Shape: (batch_dim, 1, 1, WIDTH_IMG, HEIGHT_IMG)
    return x, y

def compl_mul2d(input, weights):
    # (batch, in_channel, x, y), (in_channel, out_channel, x, y) -> (batch, out_channel, x, y)
    return torch.einsum("bixy,ioxy->boxy", input, weights)


CHECKPOINT_TO_LOAD = 'best_mgfnol2losshlossneuralop_turbulence_unroll.pth'

# ---- INSTANTIATE MODEL, DATASET, OPTIMIZER, LOSS ----
# First, get a sample to determine H, W
dummy_input_path = glob(os.path.join(dataset_dir, "*.npz"))[0]
dummy_data = np.load(dummy_input_path)['arr_0']
_, H, W = dummy_data.shape # Assuming shape (C_data, H, W)


model = Net2d(16, 20).to(device)

# Load the trained model state
if os.path.exists(CHECKPOINT_TO_LOAD):
    checkpoint = torch.load(CHECKPOINT_TO_LOAD, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Successfully loaded model from {CHECKPOINT_TO_LOAD}")
else:
    print(f"Error: Trained model checkpoint not found at {CHECKPOINT_TO_LOAD}. Please ensure training was completed and the file exists.")
    exit()

model.eval()

# Load mean and std for normalization (must be the same as used in training)
try:
    mean_std = np.load('density_stats.npz')
    mean_device = torch.from_numpy(mean_std['mean']).float().to(device)
    std_device = torch.from_numpy(mean_std['std']).float().to(device)
    if mean_device.dim() == 2:
        mean_device = mean_device.unsqueeze(0) # Becomes (1, H, W)
        std_device = std_device.unsqueeze(0) # Becomes (1, H, W)
except FileNotFoundError:
    print("Error: 'density_stats.npz' not found. This file is crucial for normalization matching training.")
    exit()
except Exception as e:
    print(f"An error occurred loading or processing 'density_stats.npz' for test: {e}")
    exit()

file_list = sorted(glob(os.path.join(dataset_dir, "*.npz")))[-SEQ_LEN_FEATURE:]
print(file_list)

seqdata2d = []
for iter_f in range(len(file_list)):
    single_data = np.load(file_list[iter_f])['arr_0'] # shape: (C, H, W)
    seqdata2d.append(single_data)
seqdata2d = np.array(seqdata2d)
feature = torch.from_numpy(seqdata2d).float().to(device)
feature = (feature - mean_device) / std_device
feature = feature.squeeze(1).unsqueeze(0)

file_list_test = sorted(glob(os.path.join(dataset_test_dir, "*.npz")))
seqtruth = []
for iter_f in range(len(file_list_test)):
    single_data = np.load(file_list_test[iter_f])['arr_0']
    seqtruth.append(single_data)
seqtruth = np.array(seqtruth)

seqpred = []
current_input_sequence_norm = feature
current_density_frame_norm = current_input_sequence_norm[:, -1:, :, :] # (B, 1, H, W)

with torch.no_grad():
    for iter_step in range(100):
        x, y = create_embd_dim(len(current_input_sequence_norm))
        x = x.squeeze(1)
        y = y.squeeze(1)

        used_input_sequence = torch.cat((current_input_sequence_norm, x, y), dim=1).unsqueeze(2)
        used_input_sequence = used_input_sequence.permute(0, 2, 3, 4, 1)

        pred_norm = model(used_input_sequence)
        pred_norm = pred_norm.squeeze()

        next_predicted_density_frame_denorm = (pred_norm * std_device) + mean_device
        next_predicted_density_frame_denorm = next_predicted_density_frame_denorm.detach().cpu().numpy()
        next_input = pred_norm.unsqueeze(0).unsqueeze(0)
        seqpred.append(next_predicted_density_frame_denorm)

        current_input_sequence_norm = torch.cat([
            current_input_sequence_norm[:, 1:, :, :], # Drop oldest (1st) frame
            next_input # Add the new predicted normalized frame
        ], dim=1) # Concatenate along the 'sequence' dimension (dim 1)

seqpred = np.array(seqpred)

#assert len(seqpred) == len(seqtruth)
vmin = np.min(seqtruth)
vmax = np.max(seqtruth)
for iter_p in range(len(seqpred)):
    fig, axes = plt.subplots(1, 2, figsize=(10, 6))

    im0 = axes[0].imshow(seqtruth[iter_p,0], cmap='viridis', vmin=vmin, vmax=vmax)
    axes[0].set_title(f"Ground Truth - t={iter_p}")
    axes[0].axis('off')

    im1 = axes[1].imshow(seqpred[iter_p], cmap='viridis', vmin=vmin, vmax=vmax)
    axes[1].set_title(f"Prediction - t={iter_p}")
    axes[1].axis('off')

    #fig.colorbar(im0, ax=axes.ravel().tolist(), shrink=0.6, label="Density")
    fig.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)
    plt.suptitle(f"Timestep {iter_p+1} - Ground Truth vs. Prediction (FNO)", fontsize=14)
    #plt.tight_layout()
    plt.tight_layout(rect=[0, 0, 1, 0.95]) # reserve space for suptitle

    fig.savefig(f"res_turb_mgfno2/comparison_t{iter_p:03d}.png", dpi=150)
    plt.close(fig)

This is the FNO result.

As you can see, the FNO version shows no too-fast motion, and it captures the movement pattern better than the U-Net.
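The comparison above is visual. A simple way to put a number on it is a per-timestep RMSE between the predicted and ground-truth frames, which also shows how quickly the error grows during the rollout. The sketch below is a minimal illustration on toy data: `rollout_rmse` is a hypothetical helper, not part of the scripts above, and it assumes that prediction `i` corresponds to ground-truth frame `i`, as the plotting code does.

```python
import numpy as np

def rollout_rmse(pred, truth):
    """Per-timestep RMSE between two (T, H, W) arrays, truncated to the shorter sequence."""
    n = min(len(pred), len(truth))
    diff = np.asarray(pred[:n], dtype=np.float64) - np.asarray(truth[:n], dtype=np.float64)
    return np.sqrt((diff ** 2).mean(axis=(1, 2)))

# Toy data: the "prediction" drifts away from the truth by 0.1 per timestep
rng = np.random.default_rng(0)
truth = rng.random((5, 8, 8))
pred = truth + 0.1 * np.arange(5)[:, None, None]

errors = rollout_rmse(pred, truth)
print(np.round(errors, 3))
```

With the arrays from the test scripts, a call along the lines of `rollout_rmse(seqpred, seqtruth[:, 0])` for each model would give two error curves to compare, provided `seqpred` really has shape `(T, H, W)` and both are in the same (denormalized) units.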

Closing

According to the authors of https://github.com/tum-pbs/autoreg-pde-diffusion, sufficiently large models are necessary for a model to learn more complex tasks well. So, our models here might improve if we add more processing layers to them. After this, I also want to recheck the U-Net code, because its result is significantly poorer than the FNO's even though the U-Net is the larger model. To be honest, I am still quite skeptical of this performance gap.

See you in another experiment.



