diff --git a/include/basalt/optical_flow/frame_to_frame_optical_flow.h b/include/basalt/optical_flow/frame_to_frame_optical_flow.h index 530c645..0cfd752 100644 --- a/include/basalt/optical_flow/frame_to_frame_optical_flow.h +++ b/include/basalt/optical_flow/frame_to_frame_optical_flow.h @@ -249,8 +249,11 @@ class FrameToFrameOpticalFlow : public OpticalFlowBase { PatchT p(old_pyr.lvl(level), old_transform.translation() / scale); - // Perform tracking on current level - patch_valid &= trackPointAtLevel(pyr.lvl(level), p, transform); + patch_valid &= p.valid; + if (patch_valid) { + // Perform tracking on current level + patch_valid &= trackPointAtLevel(pyr.lvl(level), p, transform); + } transform.translation() *= scale; } @@ -274,18 +277,24 @@ class FrameToFrameOpticalFlow : public OpticalFlowBase { transform.linear().matrix() * PatchT::pattern2; transformed_pat.colwise() += transform.translation(); - bool valid = dp.residual(img_2, transformed_pat, res); + patch_valid &= dp.residual(img_2, transformed_pat, res); - if (valid) { - Vector3 inc = -dp.H_se2_inv_J_se2_T * res; - transform *= SE2::exp(inc).matrix(); + if (patch_valid) { + const Vector3 inc = -dp.H_se2_inv_J_se2_T * res; - const int filter_margin = 2; + // avoid NaN in increment (leads to SE2::exp crashing) + patch_valid &= inc.array().isFinite().all(); - if (!img_2.InBounds(transform.translation(), filter_margin)) - patch_valid = false; - } else { - patch_valid = false; + // avoid very large increment + patch_valid &= inc.template lpNorm() < 1e6; + + if (patch_valid) { + transform *= SE2::exp(inc).matrix(); + + const int filter_margin = 2; + + patch_valid &= img_2.InBounds(transform.translation(), filter_margin); + } } } diff --git a/include/basalt/optical_flow/multiscale_frame_to_frame_optical_flow.h b/include/basalt/optical_flow/multiscale_frame_to_frame_optical_flow.h index 02d6132..f30e928 100644 --- a/include/basalt/optical_flow/multiscale_frame_to_frame_optical_flow.h +++ b/include/basalt/optical_flow/multiscale_frame_to_frame_optical_flow.h @@ -282,8 +282,11 @@ class MultiscaleFrameToFrameOpticalFlow : public OpticalFlowBase { PatchT p(old_pyr.lvl(level), old_transform.translation() / scale); - // Perform tracking on current level - patch_valid = trackPointAtLevel(pyr.lvl(level), p, transform_tmp); + patch_valid &= p.valid; + if (patch_valid) { + // Perform tracking on current level + patch_valid &= trackPointAtLevel(pyr.lvl(level), p, transform_tmp); + } if (level == static_cast(pyramid_level) + 1 && !patch_valid) { return false; @@ -315,18 +318,24 @@ class MultiscaleFrameToFrameOpticalFlow : public OpticalFlowBase { transform.linear().matrix() * PatchT::pattern2; transformed_pat.colwise() += transform.translation(); - bool valid = dp.residual(img_2, transformed_pat, res); + patch_valid &= dp.residual(img_2, transformed_pat, res); - if (valid) { - Vector3 inc = -dp.H_se2_inv_J_se2_T * res; - transform *= SE2::exp(inc).matrix(); + if (patch_valid) { + const Vector3 inc = -dp.H_se2_inv_J_se2_T * res; - const int filter_margin = 2; + // avoid NaN in increment (leads to SE2::exp crashing) + patch_valid &= inc.array().isFinite().all(); - if (!img_2.InBounds(transform.translation(), filter_margin)) - patch_valid = false; - } else { - patch_valid = false; + // avoid very large increment + patch_valid &= inc.template lpNorm() < 1e6; + + if (patch_valid) { + transform *= SE2::exp(inc).matrix(); + + const int filter_margin = 2; + + patch_valid &= img_2.InBounds(transform.translation(), filter_margin); + } } } diff --git a/include/basalt/optical_flow/patch.h b/include/basalt/optical_flow/patch.h index 246e342..aee8e9a 100644 --- a/include/basalt/optical_flow/patch.h +++ b/include/basalt/optical_flow/patch.h @@ -65,7 +65,7 @@ struct OpticalFlowPatch { EIGEN_MAKE_ALIGNED_OPERATOR_NEW - OpticalFlowPatch() { mean = 0; } + OpticalFlowPatch() = default; OpticalFlowPatch(const Image &img, const Vector2 &pos) { setFromImage(img, pos); @@ -127,6 +127,16 @@ struct OpticalFlowPatch { H_se2.ldlt().solveInPlace(H_se2_inv); H_se2_inv_J_se2_T = H_se2_inv * J_se2.transpose(); + + // NOTE: while it's very unlikely we get a source patch with all black + // pixels, since points are usually selected at corners, it doesn't cost + // much to be safe here. + + // all-black patch cannot be normalized; will result in mean of "zero" and + // H_se2_inv_J_se2_T will contain "NaN" and data will contain "inf" + valid = mean > std::numeric_limits::epsilon() && + H_se2_inv_J_se2_T.array().isFinite().all() && + data.array().isFinite().all(); } inline bool residual(const Image &img, @@ -146,6 +156,12 @@ struct OpticalFlowPatch { } } + // all-black patch cannot be normalized + if (sum < std::numeric_limits::epsilon()) { + residual.setZero(); + return false; + } + int num_residuals = 0; for (int i = 0; i < PATTERN_SIZE; i++) { @@ -162,14 +178,16 @@ struct OpticalFlowPatch { return num_residuals > PATTERN_SIZE / 2; } - Vector2 pos; - VectorP data; // negative if the point is not valid + Vector2 pos = Vector2::Zero(); + VectorP data = VectorP::Zero(); // negative if the point is not valid // MatrixP3 J_se2; // total jacobian with respect to se2 warp // Matrix3 H_se2_inv; - Matrix3P H_se2_inv_J_se2_T; + Matrix3P H_se2_inv_J_se2_T = Matrix3P::Zero(); - Scalar mean; + Scalar mean = 0; + + bool valid = false; }; template diff --git a/include/basalt/optical_flow/patch_optical_flow.h b/include/basalt/optical_flow/patch_optical_flow.h index 0710bd4..81c891d 100644 --- a/include/basalt/optical_flow/patch_optical_flow.h +++ b/include/basalt/optical_flow/patch_optical_flow.h @@ -236,9 +236,13 @@ class PatchOpticalFlow : public OpticalFlowBase { transform.translation() /= scale; - // Perform tracking on current level - patch_valid &= - trackPointAtLevel(pyr.lvl(level), patch_vec[level], transform); + // TODO: maybe we should better check patch validity when creating points + const auto& p = patch_vec[level]; + patch_valid &= p.valid; + if (patch_valid) { + // Perform tracking on current level + patch_valid &= trackPointAtLevel(pyr.lvl(level), p, transform); + } transform.translation() *= scale; } @@ -260,18 +264,24 @@ class PatchOpticalFlow : public OpticalFlowBase { transform.linear().matrix() * PatchT::pattern2; transformed_pat.colwise() += transform.translation(); - bool valid = dp.residual(img_2, transformed_pat, res); + patch_valid &= dp.residual(img_2, transformed_pat, res); - if (valid) { - Vector3 inc = -dp.H_se2_inv_J_se2_T * res; - transform *= SE2::exp(inc).matrix(); + if (patch_valid) { + const Vector3 inc = -dp.H_se2_inv_J_se2_T * res; - const int filter_margin = 2; + // avoid NaN in increment (leads to SE2::exp crashing) + patch_valid &= inc.array().isFinite().all(); - if (!img_2.InBounds(transform.translation(), filter_margin)) - patch_valid = false; - } else { - patch_valid = false; + // avoid very large increment + patch_valid &= inc.template lpNorm() < 1e6; + + if (patch_valid) { + transform *= SE2::exp(inc).matrix(); + + const int filter_margin = 2; + + patch_valid &= img_2.InBounds(transform.translation(), filter_margin); + } } }