Building Generative Models for Continuous Data via Continuous Interpolants¶
In [1]
import math
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.datasets import make_moons
Task setup¶
To demonstrate how conditional flow matching works, we use sklearn to create and sample from a custom 2D distribution.
First, we define our "data loader": the sample_moons function.
Next, we define our custom PriorDistribution, enabling the translation of 8 equidistant Gaussians into the moons distribution above.
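Although the training code below ends up using the built-in GaussianPrior, a custom multi-modal prior like the 8-Gaussian one described above is easy to sketch in plain PyTorch. The class below is a minimal illustrative sketch, not the bionemo.moco PriorDistribution API; the class name and sample_prior signature are assumptions made for this example.

import math
import torch

class EightGaussianPrior:
    # Hypothetical prior: a mixture of 8 equidistant Gaussians on a circle.
    def __init__(self, radius=4.0, std=0.1):
        angles = torch.arange(8) * (2 * math.pi / 8)
        # (8, 2) tensor of mode centers evenly spaced around the circle
        self.centers = radius * torch.stack([angles.cos(), angles.sin()], dim=-1)
        self.std = std

    def sample_prior(self, shape):
        n, dim = shape  # dim is assumed to be 2 here
        idx = torch.randint(0, 8, (n,))  # pick one mode uniformly per sample
        return self.centers[idx] + self.std * torch.randn(n, dim)

# Usage: x0 = EightGaussianPrior().sample_prior((1024, 2))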
In [2]
def sample_moons(n, normalize = False):
x1, _ = make_moons(n_samples=n, noise=0.08)
x1 = torch.Tensor(x1)
x1 = x1 * 3 - 1
if normalize:
x1 = (x1 - x1.mean(0))/x1.std(0) * 2
return x1
In [3]
x1 = sample_moons(1000)
plt.scatter(x1[:, 0], x1[:, 1])
Out[3]
<matplotlib.collections.PathCollection at 0x7eb0d639ca90>
Model creation¶
Here we define a simple 4-layer MLP and set up our optimizer.
In [4]
dim = 2
hidden_size = 64
batch_size = 256
model = torch.nn.Sequential(
torch.nn.Linear(dim + 1, hidden_size),
torch.nn.SELU(),
torch.nn.Linear(hidden_size, hidden_size),
torch.nn.SELU(),
torch.nn.Linear(hidden_size, hidden_size),
torch.nn.SELU(),
torch.nn.Linear(hidden_size, dim),
)
optimizer = torch.optim.Adam(model.parameters())
In [5]
from bionemo.moco.interpolants import ContinuousFlowMatcher
from bionemo.moco.distributions.time import UniformTimeDistribution
from bionemo.moco.distributions.prior import GaussianPrior
uniform_time = UniformTimeDistribution()
simple_prior = GaussianPrior()
sigma = 0.1
cfm = ContinuousFlowMatcher(time_distribution=uniform_time,
prior_distribution=simple_prior,
sigma=sigma,
prediction_type="velocity")
# Place both the model and the interpolant on the same device
DEVICE = "cuda"
model = model.to(DEVICE)
cfm = cfm.to_device(DEVICE)
Training loop¶
In [6]
for k in range(20000):
optimizer.zero_grad()
shape = (batch_size, dim)
x0 = cfm.sample_prior(shape).to(DEVICE)
x1 = sample_moons(batch_size).to(DEVICE)
t = cfm.sample_time(batch_size)
xt = cfm.interpolate(x1, t, x0)
ut = cfm.calculate_target(x1, x0)
vt = model(torch.cat([xt, t[:, None]], dim=-1))
loss = cfm.loss(vt, ut, target_type="velocity").mean()
loss.backward()
optimizer.step()
if (k + 1) % 5000 == 0:
print(f"{k+1}: loss {loss.item():0.3f}")
5000: loss 2.752
10000: loss 2.838
15000: loss 2.709
20000: loss 3.096
Setting up generation¶
Now we need to import the desired inference time schedule. This provides the time values we iterate over to generate from our model.
Here we show the resulting time schedule as well as the discretization between time points. Note that different inference time schedules can have different shapes, resulting in non-uniform dt.
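As a quick illustration (plain PyTorch, independent of bionemo.moco), a non-linear spacing of time points produces a non-uniform dt; the quadratic schedule below is just an assumed example.

import torch

nsteps = 10
t_linear = torch.linspace(0.0, 1.0, nsteps + 1)  # uniform spacing -> constant dt
t_quadratic = t_linear ** 2                      # non-linear spacing -> non-uniform dt
print(t_linear.diff())     # every step is 0.1
print(t_quadratic.diff())  # small steps near t=0, large steps near t=1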
In [7]
from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule
inference_sched = LinearInferenceSchedule(nsteps = 100)
schedule = inference_sched.generate_schedule().to(DEVICE)
dts = inference_sched.discretize().to(DEVICE)
schedule, dts
Out[7]
(tensor([0.0000, 0.0100, 0.0200, 0.0300, 0.0400, 0.0500, 0.0600, 0.0700, 0.0800, 0.0900, 0.1000, 0.1100, 0.1200, 0.1300, 0.1400, 0.1500, 0.1600, 0.1700, 0.1800, 0.1900, 0.2000, 0.2100, 0.2200, 0.2300, 0.2400, 0.2500, 0.2600, 0.2700, 0.2800, 0.2900, 0.3000, 0.3100, 0.3200, 0.3300, 0.3400, 0.3500, 0.3600, 0.3700, 0.3800, 0.3900, 0.4000, 0.4100, 0.4200, 0.4300, 0.4400, 0.4500, 0.4600, 0.4700, 0.4800, 0.4900, 0.5000, 0.5100, 0.5200, 0.5300, 0.5400, 0.5500, 0.5600, 0.5700, 0.5800, 0.5900, 0.6000, 0.6100, 0.6200, 0.6300, 0.6400, 0.6500, 0.6600, 0.6700, 0.6800, 0.6900, 0.7000, 0.7100, 0.7200, 0.7300, 0.7400, 0.7500, 0.7600, 0.7700, 0.7800, 0.7900, 0.8000, 0.8100, 0.8200, 0.8300, 0.8400, 0.8500, 0.8600, 0.8700, 0.8800, 0.8900, 0.9000, 0.9100, 0.9200, 0.9300, 0.9400, 0.9500, 0.9600, 0.9700, 0.9800, 0.9900], device='cuda:0'), tensor([0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100], device='cuda:0'))
Sampling from the trained model¶
In [8]
inf_size = 1024
sample = cfm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
for dt, t in zip(dts, schedule):
full_t = inference_sched.pad_time(inf_size, t, DEVICE)
vt = model(torch.cat([sample, full_t[:, None]], dim=-1)) # calculate the vector field based on the definition of the model
sample = cfm.step(vt, sample, dt, full_t)
trajectory.append(sample) # save the trajectory for plotting purposes
In [9]
import matplotlib.pyplot as plt
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
In [39]
inf_size = 1024
sample = cfm.sample_prior((inf_size, 2)).to(DEVICE)
trajectory_stoch = [sample]
vts = []
for dt, t in zip(dts, schedule):
time = inference_sched.pad_time(inf_size, t, DEVICE) #torch.full((inf_size,), t).to(DEVICE)
vt = model(torch.cat([sample, time[:, None]], dim=-1))
sample = cfm.step_score_stochastic(vt, sample, dt, time, noise_temperature=1.0, gt_mode = "tan")
trajectory_stoch.append(sample)
vts.append(vt)
In [40]
traj = torch.stack(trajectory_stoch).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(0)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
#for i in range(0, traj.shape[0]-1):
# plt.plot(traj[i, :n, 0], traj[i, :n, 1], c="olive", alpha=0.2) #, s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(1)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.title("Stochastic score sampling Temperature = 1.0")
plt.show()
What happens if you just sample from a random model?¶
In [41]
fmodel = torch.nn.Sequential(
torch.nn.Linear(dim + 1, hidden_size),
torch.nn.SELU(),
torch.nn.Linear(hidden_size, hidden_size),
torch.nn.SELU(),
torch.nn.Linear(hidden_size, hidden_size),
torch.nn.SELU(),
torch.nn.Linear(hidden_size, dim),
).to(DEVICE)
inf_size = 1024
sample = cfm.sample_prior((inf_size, 2)).to(DEVICE)
trajectory2 = [sample]
for dt, t in zip(dts, schedule):
time = inference_sched.pad_time(inf_size, t, DEVICE) #torch.full((inf_size,), t).to(DEVICE)
vt = fmodel(torch.cat([sample, time[:, None]], dim=-1))
sample = cfm.step(vt, sample, dt, time)
trajectory2.append(sample)
In [42]
n = 2000
traj = torch.stack(trajectory2).cpu().detach().numpy()
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(0)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(1)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
In [43]
import math
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
class Network(nn.Module):
def __init__(
self, dim_in: int, dim_out: int, dim_hids: List[int],
):
super().__init__()
self.layers = nn.ModuleList([
TimeLinear(dim_in, dim_hids[0]),
*[TimeLinear(dim_hids[i-1], dim_hids[i]) for i in range(1, len(dim_hids))],
TimeLinear(dim_hids[-1], dim_out)
])
def forward(self, x: torch.Tensor, t: torch.Tensor):
for i, layer in enumerate(self.layers):
x = layer(x, t)
if i < len(self.layers) - 1:
x = F.relu(x)
return x
class TimeLinear(nn.Module):
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.dim_in = dim_in
self.dim_out = dim_out
self.time_embedding = TimeEmbedding(dim_out)
self.fc = nn.Linear(dim_in, dim_out)
def forward(self, x: torch.Tensor, t: torch.Tensor):
x = self.fc(x)
alpha = self.time_embedding(t).view(-1, self.dim_out)
return alpha * x
class TimeEmbedding(nn.Module):
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
return embedding
def forward(self, t: torch.Tensor):
if t.ndim == 0:
t = t.unsqueeze(-1)
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
In [44]
from bionemo.moco.distributions.time import UniformTimeDistribution
from bionemo.moco.interpolants import DDPM
from bionemo.moco.schedules.noise.discrete_noise_schedules import DiscreteCosineNoiseSchedule, DiscreteLinearNoiseSchedule
from bionemo.moco.schedules.inference_time_schedules import DiscreteLinearInferenceSchedule
from bionemo.moco.distributions.prior import GaussianPrior
DEVICE = "cuda:0"
uniform_time = UniformTimeDistribution(discrete_time=True, nsteps = 1000)
simple_prior = GaussianPrior()
ddpm = DDPM(time_distribution=uniform_time,
prior_distribution=simple_prior,
prediction_type = "noise",
noise_schedule = DiscreteLinearNoiseSchedule(nsteps = 1000),
device=DEVICE)
Training the model¶
In [45]
# Place both the model and the interpolant on the same device
dim = 2
hidden_size = 128
num_hiddens = 3
batch_size = 256
model = Network(dim_in=dim,
dim_out=dim,
dim_hids=[hidden_size]*num_hiddens)
optimizer = torch.optim.Adam(model.parameters(), lr = 1.e-3)
DEVICE = "cuda"
model = model.to(DEVICE)
ddpm = ddpm.to_device(DEVICE)
for k in range(20000):
optimizer.zero_grad()
shape = (batch_size, dim)
x0 = ddpm.sample_prior(shape).to(DEVICE)
x1 = sample_moons(batch_size).to(DEVICE)
t = ddpm.sample_time(batch_size)
xt = ddpm.interpolate(x1, t, x0)
eps = model(xt, t)
loss = ddpm.loss(eps, x0, t).mean()
loss.backward()
optimizer.step()
if (k + 1) % 1000 == 0:
print(f"{k+1}: loss {loss.item():0.3f}")
1000: loss 0.320
2000: loss 0.372
3000: loss 0.330
4000: loss 0.409
5000: loss 0.338
6000: loss 0.378
7000: loss 0.355
8000: loss 0.394
9000: loss 0.359
10000: loss 0.338
11000: loss 0.257
12000: loss 0.293
13000: loss 0.333
14000: loss 0.329
15000: loss 0.322
16000: loss 0.302
17000: loss 0.282
18000: loss 0.331
19000: loss 0.289
20000: loss 0.322
Let's visualize what the interpolation looks like at different times during training¶
In [46]
x0 = ddpm.sample_prior(shape).to(DEVICE)
x1 = sample_moons(batch_size).to(DEVICE)
for t in range(0, 900, 100):
tt = ddpm.sample_time(batch_size)*0 + t
out = ddpm.interpolate(x1, tt, x0)
plt.scatter(out[:, 0].cpu().detach(), out[:, 1].cpu().detach())
plt.title(f"Time = {t}")
plt.show()
Create the inference time schedule and sample from the model¶
In [47]
inf_size = 1024
schedule = DiscreteLinearInferenceSchedule(nsteps = 1000, direction = "diffusion").generate_schedule(device= DEVICE)
sample = ddpm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
for t in schedule:
full_t = torch.full((inf_size,), t).to(DEVICE)
vt = model(sample, full_t) # calculate the vector field based on the definition of the model
sample = ddpm.step_noise(vt, full_t, sample)
trajectory.append(sample) # save the trajectory for plotting purposes
In [48]
import matplotlib.pyplot as plt
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
/home/dreidenbach/mambaforge/envs/moco_bionemo/lib/python3.10/site-packages/IPython/core/pylabtools.py:170: UserWarning: Creating legend with loc="best" can be slow with large amounts of data.
  fig.canvas.print_figure(bytes_io, **kw)
In [49]
inf_size = 1024
sample = ddpm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
for t in schedule:
full_t = torch.full((inf_size,), t).to(DEVICE)
eps_hat = model(sample, full_t) # calculate the vector field based on the definition of the model
sample = ddpm.step(eps_hat, full_t, sample)
trajectory.append(sample) # save the trajectory for plotting purposes
In [50]
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
Note that this produces results very similar to the stochastic score-based CFM example above that uses the underlying score function¶
Note that it makes no difference whether or not the predicted noise is converted to data inside the .step() function¶
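To see why, recall the standard DDPM algebra (stated here in generic notation, not bionemo.moco internals): given a noise prediction $\hat{\epsilon}$, the implied data prediction is

$$\hat{x}_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\hat{\epsilon}}{\sqrt{\bar{\alpha}_t}},$$

and the reverse-step posterior mean is a deterministic function of $(x_t, \hat{x}_0)$, so stepping with $\hat{\epsilon}$ directly or first converting it to $\hat{x}_0$ produces the same update.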
Let's try other cool sampling functions¶
In [51]
inf_size = 1024
schedule = DiscreteLinearInferenceSchedule(nsteps = 1000, direction = "diffusion").generate_schedule(device= DEVICE)
sample = ddpm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
for t in schedule:
full_t = torch.full((inf_size,), t).to(DEVICE)
eps_hat = model(sample, full_t) # calculate the vector field based on the definition of the model
sample = ddpm.step_ddim(eps_hat, full_t, sample)
trajectory.append(sample) # save the trajectory for plotting purposes
In [52]
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
What happens when you sample from an untrained model with DDPM¶
In [53]
model = Network(dim_in=dim,
dim_out=dim,
dim_hids=[hidden_size]*num_hiddens).to(DEVICE)
inf_size = 1024
sample = ddpm.sample_prior((inf_size, 2)).to(DEVICE)
trajectory2 = [sample]
for t in schedule:
full_t = torch.full((inf_size,), t).to(DEVICE)
vt = model(sample, full_t) # calculate the vector field based on the definition of the model
sample = ddpm.step_noise(vt, full_t, sample)
trajectory2.append(sample)
n = 2000
traj = torch.stack(trajectory2).cpu().detach().numpy()
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(0)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(1)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
Now let's switch DDPM's parameterization from noise to data¶
Here, instead of training the model to learn the noise, we want it to learn the original data. Both options are valid, and which one to choose depends on the underlying modeling task.
In [54]
from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
from bionemo.moco.interpolants.discrete_time.continuous.ddpm import DDPM
from bionemo.moco.schedules.noise.discrete_noise_schedules import DiscreteCosineNoiseSchedule, DiscreteLinearNoiseSchedule
from bionemo.moco.schedules.inference_time_schedules import DiscreteLinearInferenceSchedule
from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
DEVICE = "cuda:0"
uniform_time = UniformTimeDistribution(discrete_time=True, nsteps = 1000)
simple_prior = GaussianPrior()
ddpm = DDPM(time_distribution=uniform_time,
prior_distribution=simple_prior,
prediction_type = "data",
noise_schedule = DiscreteLinearNoiseSchedule(nsteps = 1000),
device=DEVICE)
Let's first train the model with a loss weighting that makes it theoretically equivalent to the simple noise-matching loss. See Equation 9 of https://arxiv.org/pdf/2202.00512¶
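In generic DDPM notation (this is a restatement of that equation, not a description of bionemo.moco internals), with $x_t = \sqrt{\bar{\alpha}_t}\,x + \sqrt{1-\bar{\alpha}_t}\,\epsilon$, the noise and data losses are related by

$$\|\epsilon - \hat{\epsilon}\|^2 = \frac{\bar{\alpha}_t}{1-\bar{\alpha}_t}\,\|x - \hat{x}\|^2 = \mathrm{SNR}(t)\,\|x - \hat{x}\|^2,$$

so weighting the data-prediction loss by $\mathrm{SNR}(t)$ (which is what we assume weight_type="data_to_noise" implements below) recovers the simple noise-matching objective.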
In [55]
# Place both the model and the interpolant on the same device
dim = 2
hidden_size = 128
num_hiddens = 3
batch_size = 256
model = Network(dim_in=dim,
dim_out=dim,
dim_hids=[hidden_size]*num_hiddens)
optimizer = torch.optim.Adam(model.parameters(), lr = 1.e-3)
DEVICE = "cuda"
model = model.to(DEVICE)
ddpm = ddpm.to_device(DEVICE)
for k in range(20000):
optimizer.zero_grad()
shape = (batch_size, dim)
x0 = ddpm.sample_prior(shape).to(DEVICE)
x1 = sample_moons(batch_size).to(DEVICE)
t = ddpm.sample_time(batch_size)
xt = ddpm.interpolate(x1, t, x0)
x_hat = model(xt, t)
loss = ddpm.loss(x_hat, x1, t, weight_type="data_to_noise").mean()
loss.backward()
optimizer.step()
if (k + 1) % 1000 == 0:
print(f"{k+1}: loss {loss.item():0.3f}")
1000: loss 0.504
2000: loss 1.002
3000: loss 0.446
4000: loss 1.014
5000: loss 0.375
6000: loss 1.849
7000: loss 0.489
8000: loss 1.577
9000: loss 0.314
10000: loss 0.468
11000: loss 0.332
12000: loss 1.729
13000: loss 0.374
14000: loss 0.779
15000: loss 0.536
16000: loss 6.597
17000: loss 1.269
18000: loss 0.501
19000: loss 0.546
20000: loss 0.490
In [56]
inf_size = 1024
sample = ddpm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
for t in schedule:
full_t = torch.full((inf_size,), t).to(DEVICE)
x_hat = model(sample, full_t) # calculate the vector field based on the definition of the model
sample = ddpm.step(x_hat, full_t, sample)
trajectory.append(sample) # save the trajectory for plotting purposes
In [57]
import matplotlib.pyplot as plt
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
Now let's train without the loss weighting, optimizing the true data-matching loss for comparison¶
In [58]
# Place both the model and the interpolant on the same device
dim = 2
hidden_size = 128
num_hiddens = 3
batch_size = 256
model = Network(dim_in=dim,
dim_out=dim,
dim_hids=[hidden_size]*num_hiddens)
optimizer = torch.optim.Adam(model.parameters(), lr = 1.e-3)
DEVICE = "cuda"
model = model.to(DEVICE)
ddpm = ddpm.to_device(DEVICE)
for k in range(20000):
optimizer.zero_grad()
shape = (batch_size, dim)
x0 = ddpm.sample_prior(shape).to(DEVICE)
x1 = sample_moons(batch_size).to(DEVICE)
t = ddpm.sample_time(batch_size)
xt = ddpm.interpolate(x1, t, x0)
x_hat = model(xt, t)
loss = ddpm.loss(x_hat, x1, t, weight_type="ones").mean()
loss.backward()
optimizer.step()
if (k + 1) % 1000 == 0:
print(f"{k+1}: loss {loss.item():0.3f}")
1000: loss 2.651
2000: loss 2.659
3000: loss 2.603
4000: loss 2.507
5000: loss 2.650
6000: loss 2.792
7000: loss 2.670
8000: loss 2.550
9000: loss 2.685
10000: loss 2.410
11000: loss 2.290
12000: loss 2.755
13000: loss 2.521
14000: loss 2.505
15000: loss 2.196
16000: loss 2.702
17000: loss 2.933
18000: loss 2.350
19000: loss 2.397
20000: loss 2.382
In [59]
inf_size = 1024
sample = ddpm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
for t in schedule:
full_t = torch.full((inf_size,), t).to(DEVICE)
x_hat = model(sample, full_t) # calculate the vector field based on the definition of the model
sample = ddpm.step(x_hat, full_t, sample)
trajectory.append(sample) # save the trajectory for plotting purposes
In [60]
import matplotlib.pyplot as plt
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
Now let's try a continuous-time analog of the DDPM interpolant, called VDM¶
This interpolant is used in Chroma and is described in detail at https://www.biorxiv.org/content/10.1101/2022.12.01.518682v1.full.pdf¶
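As a brief reminder of the underlying math (the standard VDM formulation, stated here in generic notation): the forward process is

$$x_t = \alpha_t\,x + \sigma_t\,\epsilon, \qquad \alpha_t^2 = \mathrm{sigmoid}(\lambda_t), \quad \sigma_t^2 = \mathrm{sigmoid}(-\lambda_t),$$

where $\lambda_t = \log \mathrm{SNR}(t)$ is supplied by the chosen SNR transform (here, LinearLogInterpolatedSNRTransform).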
In [61]
from bionemo.moco.distributions.time import UniformTimeDistribution
from bionemo.moco.interpolants import VDM
from bionemo.moco.schedules.noise.continuous_snr_transforms import CosineSNRTransform, LinearSNRTransform, LinearLogInterpolatedSNRTransform
from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule
from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
DEVICE = "cuda:0"
uniform_time = UniformTimeDistribution(discrete_time=False)
simple_prior = GaussianPrior()
vdm = VDM(time_distribution=uniform_time,
prior_distribution=simple_prior,
prediction_type = "data",
noise_schedule = LinearLogInterpolatedSNRTransform(),
device=DEVICE)
schedule = LinearInferenceSchedule(nsteps = 1000, direction="diffusion")
In [62]
# Place both the model and the interpolant on the same device
dim = 2
hidden_size = 128
num_hiddens = 3
batch_size = 256
model = Network(dim_in=dim,
dim_out=dim,
dim_hids=[hidden_size]*num_hiddens)
DEVICE = "cuda"
model = model.to(DEVICE)
In [63]
optimizer = torch.optim.Adam(model.parameters(), lr = 1.e-3)
for k in range(20000):
optimizer.zero_grad()
shape = (batch_size, dim)
x0 = vdm.sample_prior(shape).to(DEVICE)
x1 = sample_moons(batch_size).to(DEVICE)
t = vdm.sample_time(batch_size)
xt = vdm.interpolate(x1, t, x0)
x_hat = model(xt, t)
loss = vdm.loss(x_hat, x1, t, weight_type="ones").mean()
loss.backward()
optimizer.step()
if (k + 1) % 1000 == 0:
print(f"{k+1}: loss {loss.item():0.3f}")
1000: loss 1.251
2000: loss 1.152
3000: loss 1.156
4000: loss 0.908
5000: loss 1.174
6000: loss 1.355
7000: loss 1.008
8000: loss 1.567
9000: loss 1.092
10000: loss 1.290
11000: loss 1.149
12000: loss 1.350
13000: loss 1.480
14000: loss 1.061
15000: loss 1.223
16000: loss 1.180
17000: loss 1.127
18000: loss 1.351
19000: loss 1.059
20000: loss 1.074
In [64]
# DEVICE="cuda:1"
# model = model.to(DEVICE)
# vdm = vdm.to_device(DEVICE)
inf_size = 1024
sample = vdm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
ts = schedule.generate_schedule()
dts = schedule.discretize()
for dt, t in zip(dts, ts):
full_t = torch.full((inf_size,), t).to(DEVICE)
x_hat = model(sample, full_t) # calculate the vector field based on the definition of the model
sample = vdm.step(x_hat, full_t, sample, dt)
trajectory.append(sample) # save the trajectory for plotting purposes
In [65]
import matplotlib.pyplot as plt
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
In [66]
inf_size = 1024
sample = vdm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
ts = schedule.generate_schedule()
dts = schedule.discretize()
for dt, t in zip(dts, ts):
full_t = torch.full((inf_size,), t).to(DEVICE)
x_hat = model(sample, full_t) # calculate the vector field based on the definition of the model
sample = vdm.step_ddim(x_hat, full_t, sample, dt)
trajectory.append(sample) # save the trajectory for plotting purposes
In [67]
import matplotlib.pyplot as plt
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
What's interesting here is that DDIM's deterministic sampling best recovers the Flow Matching ODE samples¶
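This is expected: the deterministic DDIM update can be read as a discretization of the probability-flow ODE, the deterministic ODE that shares its marginals with the diffusion SDE, so in the limit of many steps it integrates the same kind of deterministic trajectory as the flow matching ODE $dx/dt = v(x, t)$.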
In [68]
inf_size = 1024
sample = vdm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
ts = schedule.generate_schedule()
dts = schedule.discretize()
for dt, t in zip(dts, ts):
full_t = torch.full((inf_size,), t).to(DEVICE)
x_hat = model(sample, full_t) # calculate the vector field based on the definition of the model
# sample = vdm.step_hybrid_sde(x_hat, full_t, sample, dt)
sample = vdm.step_ode(x_hat, full_t, sample, dt)
trajectory.append(sample) # save the trajectory for plotting purposes
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
In [69]
inf_size = 1024
sample = vdm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
ts = schedule.generate_schedule()
dts = schedule.discretize()
for dt, t in zip(dts, ts):
full_t = torch.full((inf_size,), t).to(DEVICE)
x_hat = model(sample, full_t) # calculate the vector field based on the definition of the model
# sample = vdm.step_hybrid_sde(x_hat, full_t, sample, dt)
sample = vdm.step_ode(x_hat, full_t, sample, dt, temperature = 1.5)
trajectory.append(sample) # save the trajectory for plotting purposes
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
In [70]
inf_size = 1024
sample = vdm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
ts = schedule.generate_schedule()
dts = schedule.discretize()
for dt, t in zip(dts, ts):
full_t = torch.full((inf_size,), t).to(DEVICE)
x_hat = model(sample, full_t) # calculate the vector field based on the definition of the model
# sample = vdm.step_hybrid_sde(x_hat, full_t, sample, dt)
sample = vdm.step_ode(x_hat, full_t, sample, dt, temperature = 0.5)
trajectory.append(sample) # save the trajectory for plotting purposes
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()
In [71]
inf_size = 1024
sample = vdm.sample_prior((inf_size, 2)).to(DEVICE) # Start with noise
trajectory = [sample]
ts = schedule.generate_schedule()
dts = schedule.discretize()
for dt, t in zip(dts, ts):
full_t = torch.full((inf_size,), t).to(DEVICE)
x_hat = model(sample, full_t) # calculate the vector field based on the definition of the model
sample = vdm.step_hybrid_sde(x_hat, full_t, sample, dt)
# sample = vdm.step_ode(x_hat, full_t, sample, dt)
trajectory.append(sample) # save the trajectory for plotting purposes
traj = torch.stack(trajectory).cpu().detach().numpy()
n = 2000
# Assuming traj is your tensor and traj.shape = (N, 2000, 2)
# where N is the number of time points, 2000 is the number of samples at each time point, and 2 is for the x and y coordinates.
plt.figure(figsize=(6, 6))
# Plot the first time point in black
plt.scatter(traj[0, :n, 0], traj[0, :n, 1], s=10, alpha=0.8, c="black", label='Prior sample z(S)')
# Plot all the rest of the time points except the first and last in olive
for i in range(1, traj.shape[0]-1):
plt.scatter(traj[i, :n, 0], traj[i, :n, 1], s=0.2, alpha=0.2, c="olive")
# Plot the last time point in blue
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label='z(0)')
# Add a second legend for "Flow" since we can't label in the loop directly
plt.scatter([], [], s=0.2, alpha=0.2, c="olive", label='Flow')
plt.legend()
plt.xticks([])
plt.yticks([])
plt.show()