Исходный код denoise_bot.ML_component.admm_core

from typing import List, Tuple

from .utils import ArrayLike, BackendModule

try:
    import cupy  # noqa F401
except ImportError:
    pass


[документация] def MC_update_X(Z: "ArrayLike", lambda_: "ArrayLike", r: float, np: "BackendModule") -> "ArrayLike": """ Обновление матрицы X. Шаг обновления низкоранговой матрицы X через пороговое сжатие сингулярных чисел (SVT). Это соответствует решению задачи: argmin_X ||X||_* + (r/2)||X - Z + U||^2_F Args: Z (ArrayLike): Текущая оценка матрицы, удовлетворяющей ограничениям. lambda_ (ArrayLike): Двойственная переменная (масштабированная). r (float): Параметр rho для ADMM, контролирующий штраф. np (BackendModule): Вычислительный бэкенд (модуль numpy или cupy). Returns: ArrayLike: Обновленная низкоранговая матрица X того же типа, что и входные. """ matrix_to_decompose = Z - lambda_ / r u, s, vt = np.linalg.svd(matrix_to_decompose, full_matrices=False) # Пороговое сжатие (soft-thresholding) сингулярных чисел threshold = 1.0 / r s_thresholded = np.maximum(s - threshold, 0) return u @ np.diag(s_thresholded) @ vt
[документация] def MC_update_Z(X: "ArrayLike", lambda_: "ArrayLike", r: float, Y: "ArrayLike", mask: "ArrayLike") -> "ArrayLike": """ Обновление матрицы Z. Шаг обновления матрицы Z, удовлетворяющей ограничениям на известные пиксели. Это соответствует проекции на множество матриц, совпадающих с Y на маске E. Args: X (ArrayLike): Текущая оценка низкоранговой матрицы. lambda_ (ArrayLike): Двойственная переменная. r (float): Параметр rho для ADMM. Y (ArrayLike): Исходная матрица с пропусками (используется для известных значений). mask (ArrayLike): Булева маска, где True - известные элементы. Returns: ArrayLike: Обновленная матрица Z. """ # Сначала вычисляем временную матрицу Z_new = X + lambda_ / r # Затем принудительно устанавливаем известные значения из исходной матрицы Y Z_new[mask] = Y[mask] return Z_new
[документация] def MC_ADMM( Y: "ArrayLike", mask: "ArrayLike", tol: float, max_iters: int, r: float, backend: "BackendModule", ) -> Tuple["ArrayLike", List[float]]: """ Дозаполняет матрицу с помощью ADMM. Решает задачу матричного дозаполнения с помощью ADMM, управляя итерационным процессом. Args: Y (ArrayLike): Исходная матрица (канал изображения) с пропусками. mask (ArrayLike): Булева маска, где True - известные элементы. tol (float): Порог для критерия остановки (изменение ядерной нормы). max_iters (int): Максимальное количество итераций. r (float): Параметр rho для ADMM. backend (BackendModule): Вычислительный бэкенд для всех операций (модуль numpy или cupy). Returns: Tuple[ArrayLike, List[float]]: Кортеж, содержащий: - Восстановленную матрицу X (на том же бэкенде, что и входные данные). - Историю значений ядерной нормы (список чисел float). """ height, width = Y.shape # Инициализация переменных X = backend.random.rand(height, width) Z = backend.random.rand(height, width) lambda_ = backend.zeros_like(Y) # Начальная проекция, чтобы Z с самого начала удовлетворял ограничениям Z[mask] = Y[mask] # Инициализация истории для отслеживания сходимости u, s, v = backend.linalg.svd(X, compute_uv=True) norm_prev = backend.sum(s).get().item() if hasattr(s, "get") else backend.sum(s).item() norms_history = [norm_prev] for i in range(max_iters): # Шаг 1: Обновление X (SVT) X = MC_update_X(Z, lambda_, r, backend) # Шаг 2: Обновление Z (Проекция) Z = MC_update_Z(X, lambda_, r, Y, mask) # Шаг 3: Обновление двойственной переменной lambda_ += (X - Z) * r # Проверка сходимости по изменению ядерной нормы u, s, v = backend.linalg.svd(X, compute_uv=True) norm_current = backend.sum(s).get().item() if hasattr(s, "get") else backend.sum(s).item() norms_history.append(norm_current) # abs() для обычных чисел Python, np.abs() для массивов if abs(norm_current - norm_prev) < tol: print(f"Сходимость достигнута на итерации {i + 1}.") break norm_prev = norm_current else: print("Достигнуто максимальное количество итераций.") return X, norms_history