最优传输采样器教程¶
In [1]
已复制!
import math
import os
import time
import copy
import matplotlib.pyplot as plt
import numpy as np
import torch
from bionemo.moco.interpolants import EquivariantOTSampler, OTSampler
from sklearn.datasets import make_moons
```python
1.1 从标准高斯分布 ($\mathrm{x}_0 \sim \pi_0$,橙色) 中采样 100 个点,再从双月形分布 ($\mathrm{x}_1 \sim \pi_1$,蓝色) 中采样另外 100 个点。用灰色线条绘制每一对 $(x_0^i, x_1^i)$ 之间的线性插值。¶
In [2]
已复制!
def sample_moons(n, normalize = False):
    """Draw ``n`` points from the two-moons toy distribution.

    The raw scikit-learn samples (noise 0.08) are converted to a torch
    tensor, scaled by 3 and shifted by -1.

    Args:
        n: Number of points to sample.
        normalize: When True, additionally standardise each coordinate
            (subtract the per-column mean, divide by the per-column std)
            and multiply the result by 2.

    Returns:
        A torch tensor holding the sampled points, one row per point.
    """
    points, _ = make_moons(n_samples=n, noise=0.08)
    moons = torch.Tensor(points) * 3 - 1
    if not normalize:
        return moons
    return (moons - moons.mean(0)) / moons.std(0) * 2
def sample_gaussian(n, dim = 2):
    """Return ``n`` standard-normal samples as an ``(n, dim)`` torch tensor."""
    shape = (n, dim)
    return torch.randn(*shape)
```python
In [3]
已复制!
# Sample the data (x1, two moons) and the noise (x0, standard Gaussian).
# Both are converted to numpy arrays for plotting.
x1 = sample_moons(100, normalize=True).numpy()
x0 = sample_gaussian(100).numpy()

# Scatter both point sets. Each series is labelled with its own name; the
# labels were previously swapped, so the legend contradicted the text.
plt.scatter(x1[:, 0], x1[:, 1], label='$x_1$')
plt.scatter(x0[:, 0], x0[:, 1], label='$x_0$')

# Draw the straight-line interpolation for every (x0[i], x1[i]) pair.
for start, end in zip(x0, x1):
    plt.plot([start[0], end[0]], [start[1], end[1]], color='k', alpha=0.2)
plt.legend()
```python
Out[3]
<matplotlib.legend.Legend at 0x78315adebca0>
1.2 初始化 OT 采样器并采样新的 $(x_0, x_1)$ 对,以最小化整个批次的传输成本。使用灰色线条绘制新对 ($x_0^i, x_1^i$) 之间的线性插值。我们可以看到,插值轨迹的交叉减少了,传输成本也降低了。¶
In [4]
已复制!
# Build the OT sampler (exact solver, single thread).
ot_sampler = OTSampler(method="exact", num_threads=1)

# Re-pair the batch with the OT sampler.
#   mask=None      -> no mask is used in this example
#   replace=False  -> no duplicates are allowed
#   sort="x0"      -> the output x0 keeps the order of the input x0
ot_sampled_x0, ot_sampled_x1, mask = ot_sampler.apply_ot(
    torch.Tensor(x0), torch.Tensor(x1), mask=None, replace=False, sort="x0"
)

# Back to numpy for plotting.
ot_sampled_x0, ot_sampled_x1 = (
    paired.numpy() for paired in (ot_sampled_x0, ot_sampled_x1)
)
```python
In [5]
已复制!
# Scatter the OT-paired points. Each series is labelled with its own name;
# the labels were previously swapped relative to the plotted arrays.
plt.scatter(ot_sampled_x1[:, 0], ot_sampled_x1[:, 1], label='$x_1$')
plt.scatter(ot_sampled_x0[:, 0], ot_sampled_x0[:, 1], label='$x_0$')

# Draw the straight-line interpolation for every OT pair. The loop runs over
# the arrays being plotted instead of relying on the length of x1.
for start, end in zip(ot_sampled_x0, ot_sampled_x1):
    plt.plot([start[0], end[0]], [start[1], end[1]], color='k', alpha=0.2)
plt.legend()
```python
Out[5]
<matplotlib.legend.Legend at 0x783152f9f4f0>
1.3 让我们看看 OT 如何帮助条件流匹配训练。我们将训练两个模型,一个使用 OT,另一个不使用 OT,并比较采样期间的流轨迹。¶
请注意,也可以通过 'ot_type' 参数在初始化 ContinuousFlowMatcher 对象时指定任意批量增强方法。为了清晰起见,这里直接使用前面已初始化的 OT 采样器。
In [6]
已复制!
from bionemo.moco.interpolants import ContinuousFlowMatcher
from bionemo.moco.distributions.time import UniformTimeDistribution
from bionemo.moco.distributions.prior import GaussianPrior
def trainCFM(use_ot=False):
    """Train a small MLP velocity field with conditional flow matching.

    The network receives the interpolated point concatenated with the time
    value and is trained for 10000 Adam steps on batches of 256 pairs of
    prior samples (x0) and two-moons samples (x1). The loss is printed
    every 5000 steps.

    Args:
        use_ot: When True, every batch is re-paired with the module-level
            ``ot_sampler`` before interpolation.

    Returns:
        Tuple ``(model, cfm)``: the trained network and the
        ContinuousFlowMatcher that was used to train it.
    """
    dim, hidden_size, batch_size = 2, 64, 256

    # Input is (x_t, t), hence dim + 1 input features.
    layers = [torch.nn.Linear(dim + 1, hidden_size), torch.nn.SELU()]
    for _ in range(2):
        layers += [torch.nn.Linear(hidden_size, hidden_size), torch.nn.SELU()]
    layers.append(torch.nn.Linear(hidden_size, dim))
    model = torch.nn.Sequential(*layers)
    optimizer = torch.optim.Adam(model.parameters())

    time_distribution = UniformTimeDistribution()
    prior = GaussianPrior()
    cfm = ContinuousFlowMatcher(
        time_distribution=time_distribution,
        prior_distribution=prior,
        sigma=0.1,
        prediction_type="velocity",
    )

    # The model and the interpolant must live on the same device.
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(DEVICE)
    cfm = cfm.to_device(DEVICE)

    for k in range(10000):
        optimizer.zero_grad()
        x0 = cfm.sample_prior((batch_size, dim)).to(DEVICE)
        x1 = sample_moons(batch_size, normalize=False).to(DEVICE)
        if use_ot:
            # The returned mask is not used in this example.
            x0, x1, mask = ot_sampler.apply_ot(
                x0, x1, mask=None, replace=False, sort="x0"
            )
        t = cfm.sample_time(batch_size)
        xt = cfm.interpolate(x1, t, x0)
        target = cfm.calculate_target(x1, x0)
        model_input = torch.cat([xt, t[:, None]], dim=-1)
        prediction = model(model_input)
        loss = cfm.loss(prediction, target, target_type="velocity").mean()
        loss.backward()
        optimizer.step()
        if (k + 1) % 5000 == 0:
            print(f"{k+1}: loss {loss.item():0.3f}")
    return model, cfm
```python
In [7]
已复制!
# Train a model with OT
# trainCFM returns (model, flow matcher); both are kept for sampling later.
ot_model, ot_cfm = trainCFM(use_ot=True)
# Train a model without OT
# Same architecture and training loop, but batches keep their random pairing.
no_ot_model, no_ot_cfm = trainCFM(use_ot=False)
```python
5000: loss 0.053 10000: loss 0.058 5000: loss 2.955 10000: loss 3.211
In [8]
已复制!
# Set up the sampling time schedule
from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule
# DEVICE inside trainCFM is a local variable, so it is defined again here
# at notebook scope for the sampling cells.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Linear inference schedule with 100 steps.
inference_sched = LinearInferenceSchedule(nsteps = 100)
# Time values passed to the models as t in the sampling loop.
schedule = inference_sched.generate_schedule().to(DEVICE)
# Per-step values passed as dt to the flow matcher's step() in the sampling loop.
dts = inference_sched.discretize().to(DEVICE)
```python
In [9]
已复制!
# Sample with the two trained models, starting from the same noise.
inf_size = 1024
ot_sample = ot_cfm.sample_prior((inf_size, 2))  # start with noise
no_ot_sample = copy.deepcopy(ot_sample)  # same starting point for both models
ot_sample, no_ot_sample = ot_sample.to(DEVICE), no_ot_sample.to(DEVICE)

# Keep every intermediate state so the trajectories can be plotted later.
ot_trajectory, no_ot_trajectory = [ot_sample], [no_ot_sample]

# This is pure inference. Without no_grad every stored state would keep its
# autograd graph alive, so memory would grow with the number of steps.
with torch.no_grad():
    for dt, t in zip(dts, schedule):
        # Broadcast the scalar time to one value per sample.
        full_t = torch.full((inf_size,), t).to(DEVICE)

        # Model with OT: predict the vector field, then take one step.
        ot_vt = ot_model(torch.cat([ot_sample, full_t[:, None]], dim=-1))
        ot_sample = ot_cfm.step(ot_vt, ot_sample, dt, full_t)

        # Model without OT: same procedure on its own state.
        no_ot_vt = no_ot_model(torch.cat([no_ot_sample, full_t[:, None]], dim=-1))
        no_ot_sample = no_ot_cfm.step(no_ot_vt, no_ot_sample, dt, full_t)

        ot_trajectory.append(ot_sample)
        no_ot_trajectory.append(no_ot_sample)
```python
1.4 两个模型预测的流轨迹可视化。使用 OT(左图),流轨迹更直,因此与不使用 OT(右图)相比,传输成本更低。¶
In [10]
已复制!
# Stack the per-step states into arrays indexed as [time point, sample, coordinate].
ot_traj = torch.stack(ot_trajectory).cpu().detach().numpy()
no_ot_traj = torch.stack(no_ot_trajectory).cpu().detach().numpy()

# Upper bound on the number of samples drawn per time point. Slicing with
# [:n] simply keeps every sample when fewer than n are available.
n = 2000

fig, ax = plt.subplots(1, 2, figsize=(12, 6))
panels = (
    (ax[0], ot_traj, "With OT"),
    (ax[1], no_ot_traj, "Without OT"),
)
for axis, traj, title in panels:
    # First time point (the prior samples) in black.
    axis.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior z(S)')
    # Every intermediate time point in olive.
    for frame in traj[1:-1]:
        axis.scatter(frame[:n, 0], frame[:n, 1], s=0.2, alpha=0.2, c="olive", zorder=1)
    # Last time point (the generated samples) in blue.
    axis.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
    # Empty scatter that only provides the "Flow" legend entry, since the
    # intermediate points above are drawn without a label.
    axis.scatter([], [], s=2, alpha=1, c="olive", label='Flow')
    axis.legend()
    axis.set_xticks([])
    axis.set_yticks([])
    axis.set_xlim(-5, 6)
    axis.set_ylim(-4, 5)
    axis.set_title(title)

plt.subplots_adjust(wspace=0.05)
plt.show()
```python
In [11]
已复制!
# Without OT: per-sample squared displacement between the first and last
# time point, averaged over all samples. No square root is taken, so the
# printed value is a mean squared distance.
first_points, last_points = no_ot_traj[0], no_ot_traj[-1]
distances = np.square(last_points - first_points).sum(-1)
average_distance = distances.mean()
print(f"Average Distance between First and Last Points without OT: {average_distance.item()}")

# With OT: same quantity for the OT-trained model.
first_points, last_points = ot_traj[0], ot_traj[-1]
distances = np.square(last_points - first_points).sum(-1)
average_distance = distances.mean()
print(f"Average Distance between First and Last Points with OT: {average_distance.item()}")
```python
Average Distance between First and Last Points without OT: 4.119970321655273 Average Distance between First and Last Points with OT: 3.9200291633605957
In [12]
已复制!
def sum_of_squared_distances(trajectory):
    """Sum the squared start-to-mid and mid-to-end displacements of a trajectory.

    The first axis of ``trajectory`` indexes time. Any remaining axes (for
    example samples and coordinates) are flattened, so for a batched
    trajectory of shape (T, B, D) the result is the total over all B samples.

    Parameters:
    - trajectory: Array-like of shape (T, ...) with T >= 1 time points.
      Nested lists are accepted and converted with ``np.asarray``.

    Returns:
    - ``||x[0] - x[T // 2]||^2 + ||x[T // 2] - x[-1]||^2`` as a float.

    Raises:
    - IndexError: if the trajectory contains no time points.
    """
    # Coerce once so that plain Python sequences support array arithmetic.
    points = np.asarray(trajectory)
    mid_idx = len(points) // 2
    start_point, mid_point, end_point = points[0], points[mid_idx], points[-1]
    # With no axis/ord given, norm() is the 2-norm of the flattened
    # difference, which is what lets batched (T, B, D) input work.
    start_to_mid_distance = np.linalg.norm(start_point - mid_point)
    mid_to_end_distance = np.linalg.norm(mid_point - end_point)
    return start_to_mid_distance**2 + mid_to_end_distance**2
# Calculate and print sum of squared distances for both trajectories
# Each array is passed whole, so the value covers every sample in the
# trajectory rather than a single path.
no_ot_sum_squared_distance = sum_of_squared_distances(no_ot_traj)
ot_sum_squared_distance = sum_of_squared_distances(ot_traj)
print("Sum of Squared Distances (start to mid + mid to end):")
print(f"Without OT: {no_ot_sum_squared_distance:.4f}")
print(f"With OT: {ot_sum_squared_distance:.4f}")
```python
Sum of Squared Distances (start to mid + mid to end): Without OT: 2667.9356 With OT: 2009.3874
2. 接下来我们介绍 Kabsch OT 采样器。Kabsch OT 采样器是 "等变 OT" 算法 (Klein et al.) 的实现。对于一批随机采样的噪声 ($\mathrm{x}_0$) 和数据 ($\mathrm{x}_1$),Kabsch OT 采样器会先将 $(x_0, x_1)$ 零中心化并用 Kabsch 算法对齐,再根据对齐后的 RMSD 来采样 $(x_0, x_1)$ 对。我们将通过一个简单的 2D 示例演示如何使用 Kabsch OT 采样器。¶
In [6]
已复制!
# Define helper functions
def rotation_matrix(angle):
    """Return the 2x2 counter-clockwise rotation matrix for ``angle`` in degrees."""
    radians = (angle/180.) * np.pi
    cos_t = np.cos(radians)
    sin_t = np.sin(radians)
    return np.array([
        [cos_t, -sin_t],
        [sin_t, cos_t],
    ])
def rotate(x, angle):
    """Rotate the row-vector points in ``x`` by ``angle`` degrees.

    Each row of ``x`` is treated as a 2D point and multiplied by the
    transpose of the matrix returned by ``rotation_matrix(angle)``.
    """
    transposed_rotation = rotation_matrix(angle).T
    return np.matmul(x, transposed_rotation)
def plot_quadrilateral(x, axis, color='C0', marker='o', label=None):
    """Draw a closed four-vertex polygon on ``axis``.

    Args:
        x: Array of shape (4, 2) with one (x, y) vertex per row, in drawing order.
        axis: Object providing ``scatter`` and ``plot`` (a matplotlib axes
            in this notebook).
        color: Color used for both the vertex markers and the edges.
        marker: Marker style for the vertices.
        label: Legend label attached to the vertex scatter.

    Returns:
        The ``axis`` that was passed in.
    """
    assert x.shape == (4, 2)
    # Vertices sit above the edges (zorder 2 vs 1) and get a black outline.
    axis.scatter(
        x[:, 0],
        x[:, 1],
        c=color,
        marker=marker,
        linewidths=1,
        edgecolors='k',
        zorder=2,
        label=label,
    )
    n_vertices = len(x)
    for i in range(n_vertices):
        # The modular index wraps the last vertex back to the first,
        # which closes the polygon.
        j = (i + 1) % n_vertices
        axis.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], c=color, zorder=1)
    return axis
```python
2.1 初始化 $\mathrm{k}_0$,其中包含两个样本。$k_0^0$ 是菱形,$k_0^1$ 是正方形。然后初始化 $\mathrm{k}_1$,它是旋转后的 $\mathrm{k}_0$。打乱 $\mathrm{k}_1$ 的顺序,使 $k_1^0$ 是旋转后的正方形,$k_1^1$ 是旋转后的菱形。在绘图时,$k_0^0$ 和 $k_1^0$ 显示为圆形点,$k_0^1$ 和 $k_1^1$ 显示为方形点。¶
In [7]
已复制!
# Two quadrilaterals, each given as four (x, y) vertices.
k0 = np.array(
    [
        [[-2, 0], [0, 1], [2, 0], [0, -1]],  # Rhombus
        [[-1, 2], [-1, 4], [1, 4], [1, 2]],  # Square
    ]
)
angles = [60, 25]
markers = ['o', 's']

# k1 holds the rotated shapes in swapped order: the rotated square comes
# first and the rotated rhombus second.
k1 = np.stack([rotate(k0[1], angles[1]), rotate(k0[0], angles[0])])

# Translate k0 and k1 in opposite directions so they do not overlap.
k0 = k0 - 2
k1 = k1 + 2
```python
In [8]
已复制!
# Draw every (k0[i], k1[i]) pair and connect the two centroids.
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
for i, (shape_k0, shape_k1) in enumerate(zip(k0, k1)):
    plot_quadrilateral(shape_k0, ax, color='C1', marker=markers[i], label='$k_0^%d$'%i)
    plot_quadrilateral(shape_k1, ax, color='C0', marker=markers[i], label='$k_1^%d$'%i)
    # Centroid of each shape is the mean of its vertices.
    centroid_k0 = shape_k0.mean(axis=0)
    centroid_k1 = shape_k1.mean(axis=0)
    # Dashed red line showing which shapes are currently paired.
    ax.plot(
        [centroid_k0[0], centroid_k1[0]],
        [centroid_k0[1], centroid_k1[1]],
        color='red', linewidth=1, linestyle='--',
    )
ax.legend()
ax.set_aspect('equal', adjustable='box')
```python
可以看到,我们人为设置了一个错误的配对:带圆点的橙色菱形与带圆点的蓝色旋转正方形被配成了一对。我们可以使用 EquivariantOTSampler 来解决这个问题。¶
2.2 初始化基于 Kabsch 的等变 OT 采样器,并采样新的 $(k_0, k_1)$ 对,以在旋转对齐后最小化整个批次的传输成本。我们可以看到,新采样的 $\mathrm{k}_1$ 的顺序已更改以匹配 $\mathrm{k}_0$。请注意,采样的 $\mathrm{k}_1$ 将被旋转,但不会被平移。¶
In [9]
已复制!
# Build the Kabsch-based equivariant OT sampler (exact solver, single thread).
kabsch_ot_sampler = EquivariantOTSampler(method="exact", num_threads=1)

# Re-pair the shapes with the EquivariantOTSampler.
#   mask=None      -> no mask is used in this example
#   replace=False  -> no duplicates are allowed
#   sort="x0"      -> the output x0 keeps the order of the input x0
kabsch_k0, kabsch_k1, mask = kabsch_ot_sampler.apply_ot(
    torch.Tensor(k0), torch.Tensor(k1), mask=None, replace=False, sort="x0"
)

# Back to numpy for plotting.
kabsch_k0, kabsch_k1 = (paired.numpy() for paired in (kabsch_k0, kabsch_k1))
```python
In [12]
已复制!
# Draw the re-paired shapes. Note that k1 is rotated to match k0.
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
for i, (shape_k0, shape_k1) in enumerate(zip(kabsch_k0, kabsch_k1)):
    plot_quadrilateral(shape_k0, ax, color='C1', marker=markers[i], label='$k_0^%d$'%i)
    plot_quadrilateral(shape_k1, ax, color='C0', marker=markers[i], label='$k_1^%d$'%i)
    # Centroid of each shape is the mean of its vertices.
    centroid_k0 = shape_k0.mean(axis=0)
    centroid_k1 = shape_k1.mean(axis=0)
    # Dashed red line connecting the centroids of the paired shapes.
    ax.plot(
        [centroid_k0[0], centroid_k1[0]],
        [centroid_k0[1], centroid_k1[1]],
        color='red', linewidth=1, linestyle='--',
    )
ax.legend()
ax.set_aspect('equal', adjustable='box')
```python
如果希望同时对旋转和平移进行对齐,可以先将数据零中心化,或者扩展 EquivariantOTSampler 对象。¶
In [ ]
已复制!