flow: add checks for numerical failure to avoid crashes in SE2::exp
This commit is contained in:
		
							parent
							
								
									bfeda2affa
								
							
						
					
					
						commit
						64a6ab4262
					
				| @ -249,8 +249,11 @@ class FrameToFrameOpticalFlow : public OpticalFlowBase { | ||||
| 
 | ||||
|       PatchT p(old_pyr.lvl(level), old_transform.translation() / scale); | ||||
| 
 | ||||
|       // Perform tracking on current level
 | ||||
|       patch_valid &= trackPointAtLevel(pyr.lvl(level), p, transform); | ||||
|       patch_valid &= p.valid; | ||||
|       if (patch_valid) { | ||||
|         // Perform tracking on current level
 | ||||
|         patch_valid &= trackPointAtLevel(pyr.lvl(level), p, transform); | ||||
|       } | ||||
| 
 | ||||
|       transform.translation() *= scale; | ||||
|     } | ||||
| @ -274,18 +277,24 @@ class FrameToFrameOpticalFlow : public OpticalFlowBase { | ||||
|           transform.linear().matrix() * PatchT::pattern2; | ||||
|       transformed_pat.colwise() += transform.translation(); | ||||
| 
 | ||||
|       bool valid = dp.residual(img_2, transformed_pat, res); | ||||
|       patch_valid &= dp.residual(img_2, transformed_pat, res); | ||||
| 
 | ||||
|       if (valid) { | ||||
|         Vector3 inc = -dp.H_se2_inv_J_se2_T * res; | ||||
|         transform *= SE2::exp(inc).matrix(); | ||||
|       if (patch_valid) { | ||||
|         const Vector3 inc = -dp.H_se2_inv_J_se2_T * res; | ||||
| 
 | ||||
|         const int filter_margin = 2; | ||||
|         // avoid NaN in increment (leads to SE2::exp crashing)
 | ||||
|         patch_valid &= inc.array().isFinite().all(); | ||||
| 
 | ||||
|         if (!img_2.InBounds(transform.translation(), filter_margin)) | ||||
|           patch_valid = false; | ||||
|       } else { | ||||
|         patch_valid = false; | ||||
|         // avoid very large increment
 | ||||
|         patch_valid &= inc.template lpNorm<Eigen::Infinity>() < 1e6; | ||||
| 
 | ||||
|         if (patch_valid) { | ||||
|           transform *= SE2::exp(inc).matrix(); | ||||
| 
 | ||||
|           const int filter_margin = 2; | ||||
| 
 | ||||
|           patch_valid &= img_2.InBounds(transform.translation(), filter_margin); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|  | ||||
| @ -282,8 +282,11 @@ class MultiscaleFrameToFrameOpticalFlow : public OpticalFlowBase { | ||||
| 
 | ||||
|       PatchT p(old_pyr.lvl(level), old_transform.translation() / scale); | ||||
| 
 | ||||
|       // Perform tracking on current level
 | ||||
|       patch_valid = trackPointAtLevel(pyr.lvl(level), p, transform_tmp); | ||||
|       patch_valid &= p.valid; | ||||
|       if (patch_valid) { | ||||
|         // Perform tracking on current level
 | ||||
|         patch_valid &= trackPointAtLevel(pyr.lvl(level), p, transform_tmp); | ||||
|       } | ||||
| 
 | ||||
|       if (level == static_cast<ssize_t>(pyramid_level) + 1 && !patch_valid) { | ||||
|         return false; | ||||
| @ -315,18 +318,24 @@ class MultiscaleFrameToFrameOpticalFlow : public OpticalFlowBase { | ||||
|           transform.linear().matrix() * PatchT::pattern2; | ||||
|       transformed_pat.colwise() += transform.translation(); | ||||
| 
 | ||||
|       bool valid = dp.residual(img_2, transformed_pat, res); | ||||
|       patch_valid &= dp.residual(img_2, transformed_pat, res); | ||||
| 
 | ||||
|       if (valid) { | ||||
|         Vector3 inc = -dp.H_se2_inv_J_se2_T * res; | ||||
|         transform *= SE2::exp(inc).matrix(); | ||||
|       if (patch_valid) { | ||||
|         const Vector3 inc = -dp.H_se2_inv_J_se2_T * res; | ||||
| 
 | ||||
|         const int filter_margin = 2; | ||||
|         // avoid NaN in increment (leads to SE2::exp crashing)
 | ||||
|         patch_valid &= inc.array().isFinite().all(); | ||||
| 
 | ||||
|         if (!img_2.InBounds(transform.translation(), filter_margin)) | ||||
|           patch_valid = false; | ||||
|       } else { | ||||
|         patch_valid = false; | ||||
|         // avoid very large increment
 | ||||
|         patch_valid &= inc.template lpNorm<Eigen::Infinity>() < 1e6; | ||||
| 
 | ||||
|         if (patch_valid) { | ||||
|           transform *= SE2::exp(inc).matrix(); | ||||
| 
 | ||||
|           const int filter_margin = 2; | ||||
| 
 | ||||
|           patch_valid &= img_2.InBounds(transform.translation(), filter_margin); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|  | ||||
| @ -65,7 +65,7 @@ struct OpticalFlowPatch { | ||||
| 
 | ||||
|   EIGEN_MAKE_ALIGNED_OPERATOR_NEW | ||||
| 
 | ||||
|   OpticalFlowPatch() { mean = 0; } | ||||
|   OpticalFlowPatch() = default; | ||||
| 
 | ||||
|   OpticalFlowPatch(const Image<const uint16_t> &img, const Vector2 &pos) { | ||||
|     setFromImage(img, pos); | ||||
| @ -127,6 +127,16 @@ struct OpticalFlowPatch { | ||||
|     H_se2.ldlt().solveInPlace(H_se2_inv); | ||||
| 
 | ||||
|     H_se2_inv_J_se2_T = H_se2_inv * J_se2.transpose(); | ||||
| 
 | ||||
|     // NOTE: while it's very unlikely we get a source patch with all black
 | ||||
|     // pixels, since points are usually selected at corners, it doesn't cost
 | ||||
|     // much to be safe here.
 | ||||
| 
 | ||||
|     // all-black patch cannot be normalized; will result in mean of "zero" and
 | ||||
|     // H_se2_inv_J_se2_T will contain "NaN" and data will contain "inf"
 | ||||
|     valid = mean > std::numeric_limits<Scalar>::epsilon() && | ||||
|             H_se2_inv_J_se2_T.array().isFinite().all() && | ||||
|             data.array().isFinite().all(); | ||||
|   } | ||||
| 
 | ||||
|   inline bool residual(const Image<const uint16_t> &img, | ||||
| @ -146,6 +156,12 @@ struct OpticalFlowPatch { | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     // all-black patch cannot be normalized
 | ||||
|     if (sum < std::numeric_limits<Scalar>::epsilon()) { | ||||
|       residual.setZero(); | ||||
|       return false; | ||||
|     } | ||||
| 
 | ||||
|     int num_residuals = 0; | ||||
| 
 | ||||
|     for (int i = 0; i < PATTERN_SIZE; i++) { | ||||
| @ -162,14 +178,16 @@ struct OpticalFlowPatch { | ||||
|     return num_residuals > PATTERN_SIZE / 2; | ||||
|   } | ||||
| 
 | ||||
|   Vector2 pos; | ||||
|   VectorP data;  // negative if the point is not valid
 | ||||
|   Vector2 pos = Vector2::Zero(); | ||||
|   VectorP data = VectorP::Zero();  // negative if the point is not valid
 | ||||
| 
 | ||||
|   // MatrixP3 J_se2;  // total jacobian with respect to se2 warp
 | ||||
|   // Matrix3 H_se2_inv;
 | ||||
|   Matrix3P H_se2_inv_J_se2_T; | ||||
|   Matrix3P H_se2_inv_J_se2_T = Matrix3P::Zero(); | ||||
| 
 | ||||
|   Scalar mean; | ||||
|   Scalar mean = 0; | ||||
| 
 | ||||
|   bool valid = false; | ||||
| }; | ||||
| 
 | ||||
| template <typename Scalar, typename Pattern> | ||||
|  | ||||
| @ -236,9 +236,13 @@ class PatchOpticalFlow : public OpticalFlowBase { | ||||
| 
 | ||||
|       transform.translation() /= scale; | ||||
| 
 | ||||
|       // Perform tracking on current level
 | ||||
|       patch_valid &= | ||||
|           trackPointAtLevel(pyr.lvl(level), patch_vec[level], transform); | ||||
|       // TODO: maybe we should better check patch validity when creating points
 | ||||
|       const auto& p = patch_vec[level]; | ||||
|       patch_valid &= p.valid; | ||||
|       if (patch_valid) { | ||||
|         // Perform tracking on current level
 | ||||
|         patch_valid &= trackPointAtLevel(pyr.lvl(level), p, transform); | ||||
|       } | ||||
| 
 | ||||
|       transform.translation() *= scale; | ||||
|     } | ||||
| @ -260,18 +264,24 @@ class PatchOpticalFlow : public OpticalFlowBase { | ||||
|           transform.linear().matrix() * PatchT::pattern2; | ||||
|       transformed_pat.colwise() += transform.translation(); | ||||
| 
 | ||||
|       bool valid = dp.residual(img_2, transformed_pat, res); | ||||
|       patch_valid &= dp.residual(img_2, transformed_pat, res); | ||||
| 
 | ||||
|       if (valid) { | ||||
|         Vector3 inc = -dp.H_se2_inv_J_se2_T * res; | ||||
|         transform *= SE2::exp(inc).matrix(); | ||||
|       if (patch_valid) { | ||||
|         const Vector3 inc = -dp.H_se2_inv_J_se2_T * res; | ||||
| 
 | ||||
|         const int filter_margin = 2; | ||||
|         // avoid NaN in increment (leads to SE2::exp crashing)
 | ||||
|         patch_valid &= inc.array().isFinite().all(); | ||||
| 
 | ||||
|         if (!img_2.InBounds(transform.translation(), filter_margin)) | ||||
|           patch_valid = false; | ||||
|       } else { | ||||
|         patch_valid = false; | ||||
|         // avoid very large increment
 | ||||
|         patch_valid &= inc.template lpNorm<Eigen::Infinity>() < 1e6; | ||||
| 
 | ||||
|         if (patch_valid) { | ||||
|           transform *= SE2::exp(inc).matrix(); | ||||
| 
 | ||||
|           const int filter_margin = 2; | ||||
| 
 | ||||
|           patch_valid &= img_2.InBounds(transform.translation(), filter_margin); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user