Aligning two coronal sections of adult mouse brain from MERFISH

In this notebook, we align two single-cell resolution spatial transcriptomics datasets, both assayed by MERFISH, of coronal sections of the adult mouse brain taken from matched locations with respect to bregma.

We will use STalign to achieve this alignment. We will first load the relevant code libraries.

[1]:
## import dependencies
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
import plotly
import requests

# import STalign
from STalign import STalign

# make plots bigger
plt.rcParams["figure.figsize"] = (12,10)

We have already downloaded single cell spatial transcriptomics datasets and placed the files in a folder called merfish_data.

We can read in the cell information for the first dataset using pandas as pd.

[2]:
# Single cell data 1
# read in data
fname = '../merfish_data/datasets_mouse_brain_map_BrainReceptorShowcase_Slice2_Replicate3_cell_metadata_S2R3.csv.gz'
df1 = pd.read_csv(fname)
print(df1.head())
                                Unnamed: 0  fov       volume    center_x
0  158338042824236264719696604356349910479   33   532.778772  617.916619  \
1  260594727341160372355976405428092853003   33  1004.430016  596.808018
2  307643940700812339199503248604719950662   33  1267.183208  578.880018
3   30863303465976316429997331474071348973   33  1403.401822  572.616017
4  313162718584097621688679244357302162401   33   507.949497  608.364018

      center_y       min_x       max_x        min_y        max_y
0  2666.520010  614.725219  621.108019  2657.545209  2675.494810
1  2763.450012  589.669218  603.946818  2757.013212  2769.886812
2  2748.978012  570.877217  586.882818  2740.489211  2757.466812
3  2766.690012  564.937217  580.294818  2756.581212  2776.798812
4  2687.418010  603.061218  613.666818  2682.493210  2692.342810

For alignment with STalign, we only need the cell centroid coordinates, so we extract them. Plotting the centroids with matplotlib.pyplot (imported as plt) gives a sense of the variation in cell density that the alignment will rely on.

[3]:
# get cell centroid coordinates
xI = np.array(df1['center_x'])
yI = np.array(df1['center_y'])

# plot
fig,ax = plt.subplots()
ax.scatter(xI,yI,s=1,alpha=0.2, label='source')
ax.legend(markerscale = 10)
[3]:
<matplotlib.legend.Legend at 0x1257e5f50>
../_images/notebooks_merfish-merfish-alignment_5_1.png

Now, we can repeat this to get cell information from the second dataset.

[4]:
# Single cell data 2
# read in data
fname = '../merfish_data/datasets_mouse_brain_map_BrainReceptorShowcase_Slice2_Replicate2_cell_metadata_S2R2.csv.gz'

df2 = pd.read_csv(fname)

# get cell centroids
xJ = np.array(df2['center_x'])
yJ = np.array(df2['center_y'])

# plot
fig,ax = plt.subplots()
ax.scatter(xJ,yJ,s=1,alpha=0.2,c='#ff7f0e', label='target')
ax.legend(markerscale = 10)
[4]:
<matplotlib.legend.Legend at 0x129366d90>
../_images/notebooks_merfish-merfish-alignment_7_1.png

Note that plotting the cell centroid positions from both datasets shows that non-linear local alignment is needed.

[5]:
# plot
fig,ax = plt.subplots()
ax.scatter(xI,yI,s=1,alpha=0.2, label='source')
ax.scatter(xJ,yJ,s=1,alpha=0.1, label= 'target')
ax.legend(markerscale = 10)
[5]:
<matplotlib.legend.Legend at 0x129a6e210>
../_images/notebooks_merfish-merfish-alignment_9_1.png

STalign relies on an iterative gradient descent to align these two images. This runs faster and gives a better result if the source and target start at a similar angle.

Evaluate the similarity of the rotation angle by viewing a side-by-side comparison. Change the value of theta_deg until the rotation angles are similar. Note: the rotation here is defined in degrees and is in the clockwise direction.

The angle chosen will be used to construct a 2x2 rotation matrix L and a 2 element translation vector T.

[6]:
theta_deg = 45
theta0 = (np.pi/180)*-theta_deg

#rotation matrix
#rotates about the origin
L = np.array([[np.cos(theta0),-np.sin(theta0)],
              [np.sin(theta0),np.cos(theta0)]])

source_L = np.matmul(L , np.array([xI, yI]))
xI_L = source_L[0]
yI_L = source_L[1]

#translation vector
#effectively makes the rotation about the centroid of I (i.e. the means of xI and yI)
#and also moves the centroid of I to the centroid of J
T = np.array([ np.mean(xI)- np.cos(theta0)*np.mean(xI) +np.sin(theta0)*np.mean(yI) - (np.mean(xI)-np.mean(xJ)),
              np.mean(yI)- np.sin(theta0)*np.mean(xI) -np.cos(theta0)*np.mean(yI) - (np.mean(yI)-np.mean(yJ))])

xI_L_T = xI_L + T[0]
yI_L_T = yI_L + T[1]


fig,ax = plt.subplots()
ax.scatter(xI_L_T,yI_L_T,s=1,alpha=0.1, label='source with initial affine transformation')
ax.scatter(xJ,yJ,s=1,alpha=0.1, label = 'target')
ax.legend(markerscale = 10)
[6]:
<matplotlib.legend.Legend at 0x129a47290>
../_images/notebooks_merfish-merfish-alignment_11_1.png
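
The initial affine step above can be written more compactly. The sketch below is a NumPy-only restatement, not part of STalign: the function name initial_affine and the synthetic point clouds are made up for illustration. It builds the same L, computes T as the target centroid minus the rotated source centroid (which is algebraically the same as the expression used for T above), and applies both to the source points.

```python
import numpy as np

def initial_affine(x_src, y_src, x_tgt, y_tgt, theta_deg):
    """Rotate source points clockwise by theta_deg about their centroid,
    then move that centroid onto the target centroid.
    (Hypothetical helper for illustration, not an STalign function.)"""
    # negative angle = clockwise when the y axis points up, as in the scatter plots
    theta = -np.deg2rad(theta_deg)
    L = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta),  np.cos(theta)]])
    c_src = np.array([np.mean(x_src), np.mean(y_src)])
    c_tgt = np.array([np.mean(x_tgt), np.mean(y_tgt)])
    # L @ p + T == L @ (p - c_src) + c_tgt
    T = c_tgt - L @ c_src
    moved = L @ np.vstack([x_src, y_src]) + T[:, None]
    return L, T, moved[0], moved[1]

# synthetic stand-ins for (xI, yI) and (xJ, yJ)
rng = np.random.default_rng(0)
x_src, y_src = rng.uniform(0, 100, 50), rng.uniform(0, 50, 50)
x_tgt, y_tgt = rng.uniform(200, 300, 60), rng.uniform(400, 450, 60)

L, T, x_new, y_new = initial_affine(x_src, y_src, x_tgt, y_tgt, theta_deg=45)
print(x_new.mean() - x_tgt.mean(), y_new.mean() - y_tgt.mean())  # both close to zero
```

Because the translation is defined relative to the centroids, changing theta_deg only changes the orientation of the source; the two centroids stay superimposed.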

Now, we will first use STalign to rasterize the single cell centroid positions into an image. Assuming the single-cell centroid coordinates are in microns, we will perform this rasterization at a 15 micron resolution (dx=15). We can visualize the resulting rasterized image.

Note that points are plotted with the origin at the bottom left, while images are typically plotted with the origin at the top left, so we’ve used invert_yaxis() to invert the y axis for visualization consistency.

[7]:
# rasterize at 15um resolution (assuming positions are in um units) and plot
XI,YI,I,fig = STalign.rasterize(xI_L_T,yI_L_T,dx=15,blur=1.5)

# plot
ax = fig.axes[0]
ax.invert_yaxis()
0 of 85958
10000 of 85958
20000 of 85958
30000 of 85958
40000 of 85958
50000 of 85958
60000 of 85958
70000 of 85958
80000 of 85958
85957 of 85958
../_images/notebooks_merfish-merfish-alignment_13_1.png
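
To make the idea of rasterizing point positions concrete, here is a minimal NumPy sketch that bins centroids into square pixels of side dx and counts the cells in each pixel. It is not STalign.rasterize: it swaps in plain histogram binning and applies no blur, and the function name rasterize_counts and the synthetic coordinates are assumptions made for illustration.

```python
import numpy as np

def rasterize_counts(x, y, dx):
    """Bin points into square pixels of side dx.
    Returns pixel-center coordinates and a (rows = y, columns = x) count image.
    (Illustrative stand-in; STalign.rasterize works differently and also blurs.)"""
    nx = int(np.floor((x.max() - x.min()) / dx)) + 1
    ny = int(np.floor((y.max() - y.min()) / dx)) + 1
    x_edges = x.min() + dx * np.arange(nx + 1)
    y_edges = y.min() + dx * np.arange(ny + 1)
    # histogram2d takes the row coordinate first, so pass y before x
    counts, _, _ = np.histogram2d(y, x, bins=[y_edges, x_edges])
    x_centers = 0.5 * (x_edges[:-1] + x_edges[1:])
    y_centers = 0.5 * (y_edges[:-1] + y_edges[1:])
    return x_centers, y_centers, counts

# synthetic centroids, assumed to be in microns
rng = np.random.default_rng(1)
x = rng.uniform(0, 3000, 5000)
y = rng.uniform(0, 2000, 5000)

X, Y, image = rasterize_counts(x, y, dx=15)
print(image.shape, int(image.sum()))
```

A smaller dx gives a finer image with fewer cells per pixel, and a larger dx gives a coarser, smoother image; this trade-off is the reason the resolution is worth choosing deliberately.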

Repeat the rasterization for the target dataset.

[8]:
# rasterize and plot
XJ,YJ,J,fig = STalign.rasterize(xJ,yJ,dx=15, blur=1.5)
ax = fig.axes[0]
ax.invert_yaxis()
0 of 84172
10000 of 84172
20000 of 84172
30000 of 84172
40000 of 84172
50000 of 84172
60000 of 84172
70000 of 84172
80000 of 84172
84171 of 84172
../_images/notebooks_merfish-merfish-alignment_15_1.png

We can also plot the rasterized images next to each other.

[9]:
# get extent of images
extentI = STalign.extent_from_x((YI,XI))
extentJ = STalign.extent_from_x((YJ,XJ))

# plot rasterized images
fig,ax = plt.subplots(1,2)
ax[0].imshow(I[0], extent=extentI)
ax[1].imshow(J[0], extent=extentJ)
ax[0].invert_yaxis()
ax[1].invert_yaxis()
../_images/notebooks_merfish-merfish-alignment_17_0.png

Now we will perform our alignment. There are many parameters that can be tuned for performing this alignment. If we don’t specify parameters, defaults will be used.

[10]:
%%time
# run LDDMM
# specify device (default device for STalign.LDDMM is cpu)
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# keep all other parameters default
params = {
            'niter': 10000,
            'device':device,
            'epV': 50
          }

Ifoo = np.vstack((I, I, I)) # make RGB instead of greyscale
Jfoo = np.vstack((J, J, J)) # make RGB instead of greyscale
out = STalign.LDDMM([YI,XI],Ifoo,[YJ,XJ],Jfoo,**params)
/Users/gohtaaihara/.local/share/virtualenvs/STalign-oDkYRjeO/lib/python3.11/site-packages/STalign/STalign.py:1043: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  L = torch.tensor(L,device=device,dtype=dtype,requires_grad=True)
/Users/gohtaaihara/.local/share/virtualenvs/STalign-oDkYRjeO/lib/python3.11/site-packages/STalign/STalign.py:1044: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  T = torch.tensor(T,device=device,dtype=dtype,requires_grad=True)
/Users/gohtaaihara/.local/share/virtualenvs/STalign-oDkYRjeO/lib/python3.11/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/TensorShape.cpp:3484.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
/Users/gohtaaihara/.local/share/virtualenvs/STalign-oDkYRjeO/lib/python3.11/site-packages/STalign/STalign.py:1301: UserWarning: Data has no positive values, and therefore cannot be log-scaled.
  axE[2].set_yscale('log')
CPU times: user 2h 39min 42s, sys: 31min 2s, total: 3h 10min 45s
Wall time: 1h 27min 7s
../_images/notebooks_merfish-merfish-alignment_19_2.png
../_images/notebooks_merfish-merfish-alignment_19_3.png
../_images/notebooks_merfish-merfish-alignment_19_4.png
[11]:
# get necessary output variables
A = out['A']
v = out['v']
xv = out['xv']

Plots generated throughout the alignment can be used to give you a sense of whether the parameter choices are appropriate and whether your alignment is converging on a solution.

We can also evaluate the resulting alignment by applying the transformation to visualize how our source and target images were deformed to achieve the alignment.

[12]:
# apply transform
phii = STalign.build_transform(xv,v,A,XJ=[YJ,XJ],direction='b')
phiI = STalign.transform_image_atlas_to_target(xv,v,A,[YI,XI],Ifoo,[YJ,XJ])

# plot with grids
fig,ax = plt.subplots()
levels = np.arange(-100000,100000,1000)
ax.contour(XJ,YJ,phii[...,0],colors='r',linestyles='-',levels=levels)
ax.contour(XJ,YJ,phii[...,1],colors='g',linestyles='-',levels=levels)
ax.set_aspect('equal')
ax.set_title('source to target')
ax.imshow(phiI.permute(1,2,0)/torch.max(phiI),extent=extentJ)
ax.invert_yaxis()
/Users/gohtaaihara/.local/share/virtualenvs/STalign-oDkYRjeO/lib/python3.11/site-packages/STalign/STalign.py:1660: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  A = torch.tensor(A)
/Users/gohtaaihara/.local/share/virtualenvs/STalign-oDkYRjeO/lib/python3.11/site-packages/STalign/STalign.py:1661: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  if v is not None: v = torch.tensor(v)
/Users/gohtaaihara/.local/share/virtualenvs/STalign-oDkYRjeO/lib/python3.11/site-packages/STalign/STalign.py:1673: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  XJ = torch.tensor(XJ)
../_images/notebooks_merfish-merfish-alignment_22_1.png

Note that because of our use of LDDMM, the resulting transformation is invertible.

[13]:
# transform is invertible
phi = STalign.build_transform(xv,v,A,XJ=[YI,XI],direction='f')
phiiJ = STalign.transform_image_target_to_atlas(xv,v,A,[YJ,XJ],Jfoo,[YI,XI])

# plot with grids
fig,ax = plt.subplots()
levels = np.arange(-100000,100000,1000)
ax.contour(XI,YI,phi[...,0],colors='r',linestyles='-',levels=levels)
ax.contour(XI,YI,phi[...,1],colors='g',linestyles='-',levels=levels)
ax.set_aspect('equal')
ax.set_title('target to source')
ax.imshow(phiiJ.permute(1,2,0)/torch.max(phiiJ),extent=extentI)
ax.invert_yaxis()
../_images/notebooks_merfish-merfish-alignment_24_0.png
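
The invertibility comes from the transformation being the flow of a smooth velocity field: following the field forward and then following its negative backward returns each point to where it started. The sketch below illustrates this with a made-up, stationary 2D velocity field and simple Euler integration. It is only an illustration of the principle; it is not STalign's integrator, and LDDMM generally uses a time-varying field.

```python
import numpy as np

def velocity(p):
    """A smooth, made-up 2D velocity field evaluated at points p of shape (N, 2)."""
    return 0.5 * np.stack([np.sin(p[:, 1]), np.cos(p[:, 0])], axis=1)

def flow(p, direction, n_steps=1000):
    """Integrate dp/dt = direction * velocity(p) over unit time with Euler steps."""
    dt = 1.0 / n_steps
    p = p.copy()
    for _ in range(n_steps):
        p = p + direction * dt * velocity(p)
    return p

rng = np.random.default_rng(0)
points = rng.uniform(-3, 3, size=(200, 2))

forward = flow(points, +1.0)       # deform the points
round_trip = flow(forward, -1.0)   # undo the deformation

max_error = np.abs(round_trip - points).max()
print(max_error)  # small, limited only by the Euler step size
```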

Finally, we can apply our STalign transform to the original sets of single cell centroid positions (with initial affine transformation) to achieve their new aligned positions.

[14]:
# apply transform to original points
tpointsI= STalign.transform_points_atlas_to_target(xv,v,A, np.stack([yI_L_T, xI_L_T], 1))

#switch from row column coordinates (y,x) to (x,y)
xI_LDDMM = tpointsI[:,1]
yI_LDDMM = tpointsI[:,0]
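
The coordinate swap is easy to get wrong, so here is a tiny NumPy illustration of the convention used above: points are stacked as (y, x), i.e. row then column, before the transform, and unpacked in the same order afterwards. The values are arbitrary.

```python
import numpy as np

x = np.array([10.0, 20.0, 30.0])
y = np.array([1.0, 2.0, 3.0])

# shape (N, 2): column 0 holds y (row), column 1 holds x (column)
points_rc = np.stack([y, x], axis=1)

# unpack in the same order to get back to (x, y)
x_back = points_rc[:, 1]
y_back = points_rc[:, 0]
print(points_rc.shape)
```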

And we can visualize the results.

[15]:
# plot results
fig,ax = plt.subplots()
ax.scatter(xI,yI,s=1,alpha=0.1, label='source')
ax.scatter(xI_LDDMM,yI_LDDMM,s=1,alpha=0.1, label = 'source aligned')
ax.scatter(xJ,yJ,s=1,alpha=0.1, label='target')
ax.legend(markerscale = 10)
[15]:
<matplotlib.legend.Legend at 0x12b282410>
../_images/notebooks_merfish-merfish-alignment_28_1.png
[16]:
fig,ax = plt.subplots(1,2)
ax[0].scatter(xI_L_T,yI_L_T,s=1,alpha=0.1, label='source with initial affine transformation')
ax[0].scatter(xJ,yJ,s=1,alpha=0.1, label='target')
ax[1].scatter(xI_LDDMM,yI_LDDMM,s=1,alpha=0.1, label = 'source STaligned')
ax[1].scatter(xJ,yJ,s=1,alpha=0.1, label='target')
ax[0].legend(markerscale = 10, loc = 'lower left')
ax[1].legend(markerscale = 10, loc = 'lower left')
[16]:
<matplotlib.legend.Legend at 0x12b4cf2d0>
../_images/notebooks_merfish-merfish-alignment_29_1.png

And we can save the new aligned positions by appending them to our original data.

[17]:
df3 = pd.DataFrame(
    {
        "aligned_x": xI_LDDMM,
        "aligned_y": yI_LDDMM,
    },
)

results = pd.concat([df1, df3], axis=1)
results.head()
[17]:
                                Unnamed: 0  fov       volume    center_x
0  158338042824236264719696604356349910479   33   532.778772  617.916619
1  260594727341160372355976405428092853003   33  1004.430016  596.808018
2  307643940700812339199503248604719950662   33  1267.183208  578.880018
3   30863303465976316429997331474071348973   33  1403.401822  572.616017
4  313162718584097621688679244357302162401   33   507.949497  608.364018

      center_y       min_x       max_x        min_y        max_y
0  2666.520010  614.725219  621.108019  2657.545209  2675.494810
1  2763.450012  589.669218  603.946818  2757.013212  2769.886812
2  2748.978012  570.877217  586.882818  2740.489211  2757.466812
3  2766.690012  564.937217  580.294818  2756.581212  2776.798812
4  2687.418010  603.061218  613.666818  2682.493210  2692.342810

     aligned_x    aligned_y
0  1073.122538  7618.842417
1  1113.700116  7726.863827
2  1092.264086  7726.675496
3  1097.931997  7748.184470
4  1078.207201  7645.926869
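
One thing to keep in mind with pd.concat(..., axis=1) is that rows are matched by index label, not by position. Here both data frames have the default 0..N-1 index, so the rows line up. The toy example below, with made-up values, shows what would happen if the cell table had been filtered or re-indexed first, and one way to handle it.

```python
import pandas as pd

# a cell table with a non-default index, e.g. after filtering rows
cells = pd.DataFrame({"center_x": [1.0, 2.0, 3.0], "center_y": [4.0, 5.0, 6.0]},
                     index=[10, 11, 12])
aligned = pd.DataFrame({"aligned_x": [1.5, 2.5, 3.5], "aligned_y": [4.5, 5.5, 6.5]})

# indices differ, so concat takes their union: 6 rows padded with NaN
misaligned = pd.concat([cells, aligned], axis=1)

# give both frames the same index, then each cell gets its aligned position
aligned.index = cells.index
results = pd.concat([cells, aligned], axis=1)
print(misaligned.shape, results.shape)
```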

Finally, we will write the results to a compressed .csv.gz file named mouse_brain_map_BrainReceptorShowcase_Slice2_Replicate3_STalign_to_Slice2_Replicate2.csv.gz

[18]:
results.to_csv('../merfish_data/mouse_brain_map_BrainReceptorShowcase_Slice2_Replicate3_STalign_to_Slice2_Replicate2.csv.gz',
               compression='gzip')
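
To check that a file written this way reads back as expected, the round trip can be tried on a small stand-in table. The sketch below uses a temporary directory, a made-up file name and made-up values. Since to_csv writes the row index by default, passing index_col=0 to read_csv restores it as the index; without it, the saved index comes back as an extra column.

```python
import os
import tempfile

import numpy as np
import pandas as pd

# small stand-in for the results table
results = pd.DataFrame({"center_x": [617.9, 596.8], "aligned_x": [1073.1, 1113.7]})

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "aligned_example.csv.gz")  # hypothetical file name
    results.to_csv(path, compression="gzip")

    # restore the saved index as the index
    reloaded = pd.read_csv(path, index_col=0)
    # without index_col, the saved index becomes an extra first column
    with_index_column = pd.read_csv(path)

print(list(reloaded.columns))
print(list(with_index_column.columns))
```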