revise self._unfold2d(x, ws=8) ? #52

Open
longzeyilang opened this issue Aug 27, 2024 · 3 comments
@longzeyilang

Hi, I trained on my own data (image size is about 128x128) and changed the model as follows:

```python
self.block1 = nn.Sequential(
    BasicLayer( 1,  8, stride=1),
    BasicLayer( 8, 24, stride=1),
    BasicLayer(24, 64, stride=1),
)
self.block2 = nn.Sequential(
    BasicLayer(64, 64, stride=2),
    BasicLayer(64, 64, stride=1),
    BasicLayer(64, 64, stride=1),
)
self.block3 = nn.Sequential(
    BasicLayer( 64, 128, stride=2),
    BasicLayer(128, 128, stride=1),
    BasicLayer(128, 128, stride=1),
    BasicLayer(128,  64, 1, padding=0),
)
self.block_fusion = nn.Sequential(
    BasicLayer(64, 64, stride=1),
    BasicLayer(64, 64, stride=1),
    nn.Conv2d(64, 64, 1, padding=0)
)
self.heatmap_head = nn.Sequential(
    BasicLayer(64, 64, 1, padding=0),
    BasicLayer(64, 64, 1, padding=0),
    nn.Conv2d(64, 1, 1),
    nn.Sigmoid()
)
self.keypoint_head = nn.Sequential(
    BasicLayer( 4, 64, 1, padding=0),
    BasicLayer(64, 64, 1, padding=0),
    BasicLayer(64, 64, 1, padding=0),
    nn.Conv2d(64, 5, 1),
)
```

and changed the forward as follows:

```python
def forward(self, x):
    """
    input:
        x -> torch.Tensor(B, C, H, W) grayscale or rgb images
    return:
        feats     -> torch.Tensor(B, 64, H/8, W/8) dense local features
        keypoints -> torch.Tensor(B, 65, H/8, W/8) keypoint logit map
        heatmap   -> torch.Tensor(B, 1, H/8, W/8) reliability map
    """
    # don't backprop through normalization
    with torch.no_grad():
        x = x.mean(dim=1, keepdim=True)
        x = self.norm(x)

    # main backbone
    x1 = self.block1(x)
    x2 = self.block2(x1)
    x3 = self.block3(x2)
    x4 = F.interpolate(x3, (x2.shape[-2], x2.shape[-1]), mode='bilinear')
    feats = self.block_fusion(x4 + x2)

    # heads
    heatmap = self.heatmap_head(feats)                       # Reliability map
    keypoints = self.keypoint_head(self._unfold2d(x, ws=2))  # Keypoint map logits
    return feats, keypoints, heatmap
```

With the `_unfold2d` ws changed to 2, how should I revise the keypoint_head, and how should I revise losses.py?
Thank you.
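
For reference, the channel arithmetic implied by the layers above: unfolding each ws x ws block of a single-channel image into channels gives ws*ws = 4 input channels for the keypoint head when ws=2, and the head then predicts ws*ws + 1 = 5 logits per cell (4 positions plus a non-keypoint bin), matching the posted layers. A minimal, hypothetical shape check, using `torch.nn.functional.unfold` as a stand-in for the model's `_unfold2d`:

```python
import torch
import torch.nn.functional as F

# Hypothetical shape check (not the repo's _unfold2d): emulate an unfold with
# ws=2 on a 128x128 grayscale image and confirm the keypoint head's channel counts.
ws = 2
x = torch.randn(1, 1, 128, 128)                    # (B, C, H, W)
patches = F.unfold(x, kernel_size=ws, stride=ws)   # (B, C*ws*ws, (H/ws)*(W/ws))
patches = patches.view(1, ws * ws, 128 // ws, 128 // ws)

print(patches.shape)   # torch.Size([1, 4, 64, 64]) -> keypoint_head input channels
print(ws * ws + 1)     # 5 -> logits per cell (4 positions + 1 non-keypoint bin)
```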

@longzeyilang
Author

```python
import torch
import torch.nn.functional as F

from modules.dataset.megadepth import megadepth_warper
from modules.training import utils
from third_party.alike_wrapper import extract_alike_kpts
from modules.model_small import UNFLOD_WS


def dual_softmax_loss(X, Y, temp = 0.2):
    if X.size() != Y.size() or X.dim() != 2 or Y.dim() != 2:
        raise RuntimeError('Error: X and Y shapes must match and be 2D matrices')

    dist_mat = (X @ Y.t()) * temp
    conf_matrix12 = F.log_softmax(dist_mat, dim=1)
    conf_matrix21 = F.log_softmax(dist_mat.t(), dim=1)

    with torch.no_grad():
        conf12 = torch.exp( conf_matrix12 ).max(dim=-1)[0]
        conf21 = torch.exp( conf_matrix21 ).max(dim=-1)[0]
        conf = conf12 * conf21

    target = torch.arange(len(X), device = X.device)

    loss = F.nll_loss(conf_matrix12, target) + \
           F.nll_loss(conf_matrix21, target)

    return loss, conf

def smooth_l1_loss(input, target, beta=2.0, size_average=True):
    diff = torch.abs(input - target)
    loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
    return loss.mean() if size_average else loss.sum()

def fine_loss(f1, f2, pts1, pts2, fine_module, ws=7):
    '''
    Compute Fine features and spatial loss
    '''
    C, H, W = f1.shape
    N = len(pts1)

    #Sort random offsets
    with torch.no_grad():
        a = -(ws//2)
        b = (ws//2)
        offset_gt = (a - b) * torch.rand(N, 2, device = f1.device) + b
        pts2_random = pts2 + offset_gt

    #pdb.set_trace()
    patches1 = utils.crop_patches(f1.unsqueeze(0), (pts1+0.5).long(), size=ws).view(C, N, ws * ws).permute(1, 2, 0) #[N, ws*ws, C]
    patches2 = utils.crop_patches(f2.unsqueeze(0), (pts2_random+0.5).long(), size=ws).view(C, N, ws * ws).permute(1, 2, 0)  #[N, ws*ws, C]

    #Apply transformer
    patches1, patches2 = fine_module(patches1, patches2)

    features = patches1.view(N, ws, ws, C)[:, ws//2, ws//2, :].view(N, 1, 1, C) # [N, 1, 1, C]
    patches2 = patches2.view(N, ws, ws, C) # [N, w, w, C]

    #Dot Product
    heatmap_match = (features * patches2).sum(-1)
    offset_coords = utils.subpix_softmax2d(heatmap_match)

    #Invert offset because center crop inverts it
    offset_gt = -offset_gt

    #MSE
    error = ((offset_coords - offset_gt)**2).sum(-1).mean()
    return error

def alike_distill_loss(kpts, img):
    C, H, W = kpts.shape
    kpts = kpts.permute(1,2,0)
    img = img.permute(1,2,0).expand(-1,-1,3).cpu().numpy() * 255

    with torch.no_grad():
        alike_kpts = torch.tensor(extract_alike_kpts(img), device=kpts.device)
        labels = torch.ones((H, W), dtype = torch.long, device = kpts.device) * UNFLOD_WS*UNFLOD_WS # -> Default is non-keypoint (bin 64)
        offsets = (((alike_kpts/UNFLOD_WS) - (alike_kpts/UNFLOD_WS).long())*UNFLOD_WS).long()
        offsets =  offsets[:, 0] + UNFLOD_WS*offsets[:, 1]  # Linear IDX
        labels[(alike_kpts[:,1]/UNFLOD_WS).long(), (alike_kpts[:,0]/UNFLOD_WS).long()] = offsets

    kpts = kpts.view(-1,C)
    labels = labels.view(-1)

    mask = labels < UNFLOD_WS*UNFLOD_WS
    idxs_pos = mask.nonzero().flatten()
    idxs_neg = (~mask).nonzero().flatten()
    perm = torch.randperm(idxs_neg.size(0))[:len(idxs_pos)//32]
    idxs_neg = idxs_neg[perm]
    idxs = torch.cat([idxs_pos, idxs_neg])

    kpts = kpts[idxs]
    labels = labels[idxs]

    with torch.no_grad():
        predicted = kpts.max(dim=-1)[1]
        acc =  (labels == predicted)
        acc = acc.sum() / len(acc)

    kpts = F.log_softmax(kpts)
    loss = F.nll_loss(kpts, labels, reduction = 'mean')

    return loss, acc

def keypoint_position_loss(kpts1, kpts2, pts1, pts2, softmax_temp = 1.0):
    '''
    Computes coordinate classification loss, by re-interpreting the 64 bins to 8x8 grid and optimizing
    for correct offsets
    '''
    C, H, W = kpts1.shape
    kpts1 = kpts1.permute(1,2,0) * softmax_temp
    kpts2 = kpts2.permute(1,2,0) * softmax_temp

    with torch.no_grad():
        #Generate meshgrid
        x, y = torch.meshgrid(torch.arange(W, device=kpts1.device), torch.arange(H, device=kpts1.device), indexing ='xy')
        xy = torch.cat([x.unsqueeze(-1), y.unsqueeze(-1)], dim=-1)
        xy*=8

        #Generate collision map
        hashmap = torch.ones((H*UNFLOD_WS, W*UNFLOD_WS, 2), dtype = torch.long, device = kpts1.device) * -1
        hashmap[(pts1[:,1]).long(), (pts1[:,0]).long(), :] = (pts2).long()

        #Estimate offset of src kpts
        _, kpts1_offsets = kpts1.max(dim=-1)
        kpts1_offsets_x = kpts1_offsets  % UNFLOD_WS
        kpts1_offsets_y = kpts1_offsets // UNFLOD_WS
        kpts1_offsets_xy = torch.cat([kpts1_offsets_x.unsqueeze(-1),
                                      kpts1_offsets_y.unsqueeze(-1)], dim=-1)
        #pdb.set_trace()
        kpts1_coords = xy + kpts1_offsets_xy

        #find src -> tgt pts
        kpts1_coords = kpts1_coords.view(-1,2)
        gt_12 = hashmap[kpts1_coords[:,1], kpts1_coords[:,0]]
        mask_valid = torch.all(gt_12 >= 0, dim=-1)
        gt_12 = gt_12[mask_valid]

        #find offset labels
        labels2 = (gt_12/UNFLOD_WS) - (gt_12/UNFLOD_WS).long()
        labels2 = (labels2 * UNFLOD_WS).long()
        labels2 = labels2[:, 0] + UNFLOD_WS*labels2[:, 1] #linear index

    kpts2_selected = kpts2[(gt_12[:, 1]/UNFLOD_WS).long(), (gt_12[:, 0]/UNFLOD_WS).long()]

    kpts1_selected = F.log_softmax(kpts1.view(-1,C)[mask_valid], dim=-1)
    kpts2_selected = F.log_softmax(kpts2_selected, dim=-1)

    #Here we enforce softmax to keep current max on src kps
    with torch.no_grad():
        _, labels1 =  kpts1_selected.max(dim=-1)

    predicted2 = kpts2_selected.max(dim=-1)[1]
    acc =  (labels2 == predicted2)
    acc = acc.sum() / len(acc)

    loss = F.nll_loss(kpts1_selected, labels1, reduction = 'mean') + \
           F.nll_loss(kpts2_selected, labels2, reduction = 'mean')

    return loss, acc

def coordinate_classification_loss(coords1, pts1, pts2, conf):
    '''
    Computes the fine coordinate classification loss, by re-interpreting the 64 bins to 8x8 grid and optimizing
    for correct offsets after warp
    '''
    #Do not backprop coordinate warps
    with torch.no_grad():
        coords1_detached = pts1 * UNFLOD_WS
        #find offset
        offsets1_detached = (coords1_detached/UNFLOD_WS) - (coords1_detached/UNFLOD_WS).long()
        offsets1_detached = (offsets1_detached * UNFLOD_WS).long()
        labels1 = offsets1_detached[:, 0] + UNFLOD_WS*offsets1_detached[:, 1]

    #pdb.set_trace()
    coords1_log = F.log_softmax(coords1, dim=-1)

    predicted = coords1.max(dim=-1)[1]
    acc =  (labels1 == predicted)
    acc = acc[conf > 0.1]
    acc = acc.sum() / len(acc)

    loss = F.nll_loss(coords1_log, labels1, reduction = 'none')

    #Weight loss by confidence, giving more emphasis on reliable matches
    conf = conf / conf.sum()
    loss = (loss * conf).sum()

    return loss * 2., acc

def keypoint_loss(heatmap, target):
    # Compute L1 loss
    L1_loss = F.l1_loss(heatmap, target)
    return L1_loss * 3.0

def hard_triplet_loss(X, Y, margin = 0.5):
    if X.size() != Y.size() or X.dim() != 2 or Y.dim() != 2:
        raise RuntimeError('Error: X and Y shapes must match and be 2D matrices')

    dist_mat = torch.cdist(X, Y, p=2.0)
    dist_pos = torch.diag(dist_mat)
    dist_neg = dist_mat + 100.*torch.eye(*dist_mat.size(), dtype = dist_mat.dtype,
            device = dist_mat.get_device() if dist_mat.is_cuda else torch.device("cpu"))
    #filter repeated patches on negative distances to avoid weird stuff on gradients
    dist_neg = dist_neg + dist_neg.le(0.01).float()*100.

    #Margin Ranking Loss
    hard_neg = torch.min(dist_neg, 1)[0]
    loss = torch.clamp(margin + dist_pos - hard_neg, min=0.)
    return loss.mean()
```
This is the loss file with UNFLOD_WS = 2; please check.
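
One hypothetical sanity check for the relabelled bins (not from the repo): with UNFLOD_WS = 2, every offset label computed the way `alike_distill_loss` and `coordinate_classification_loss` compute it should fall in [0, 3], with 4 reserved as the non-keypoint bin, matching the 5-channel keypoint head:

```python
import torch

UNFLOD_WS = 2  # window size used above

# Hypothetical check: random keypoint coordinates, labelled with the same
# arithmetic as alike_distill_loss, must land in bins 0..3 (4 = non-keypoint).
kpts = torch.rand(1000, 2) * 128                       # fake (x, y) coords in a 128x128 image
offsets = (((kpts / UNFLOD_WS) - (kpts / UNFLOD_WS).long()) * UNFLOD_WS).long()
labels = offsets[:, 0] + UNFLOD_WS * offsets[:, 1]     # linear index inside the 2x2 cell

assert labels.min() >= 0 and labels.max() < UNFLOD_WS * UNFLOD_WS
print(labels.unique())                                 # tensor([0, 1, 2, 3])
```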

@longzeyilang
Author

@guipotje

@guipotje
Collaborator

Hi @longzeyilang,

After a quick review, your updates look correct in theory; the only problem I see is that a 2x2 patch provides too little context for the keypoint head to be effective.

What kind of issues are you experiencing?
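
As a rough, back-of-envelope illustration of the context point above (assuming a grayscale input, not from the repo): with the original ws=8 each keypoint-head cell sees an 8x8 block of raw pixels and classifies over 65 bins, while with ws=2 it sees only a 2x2 block and 5 bins.

```python
# Back-of-envelope comparison of per-cell context for the keypoint head
for ws in (8, 2):
    pixels_per_cell = ws * ws   # raw grayscale pixels the head sees per cell
    bins = ws * ws + 1          # position bins + 1 non-keypoint bin
    print(f"ws={ws}: {pixels_per_cell} pixels/cell, {bins} logits")
# ws=8: 64 pixels/cell, 65 logits
# ws=2: 4 pixels/cell, 5 logits
```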
