Исходный код denoise_bot.ML_component.main_model

import os
from typing import IO, Optional, Union

import numpy as np

from . import admm_core, utils, visualize


[документация] def create_mask_from_damaged(damaged_image: np.ndarray, threshold: int = 10) -> np.ndarray: """ Отвечает за создание маски для матрицы. Создает маску известных пикселей на основе поврежденного изображения (работает с NumPy массивами). Пиксель считается известным, если он не является почти черным. Args: damaged_image (np.ndarray): Входное изображение в формате NumPy (H, W, C). threshold (int): Порог яркости для определения черного пикселя (в диапазоне 0-255). Returns: np.ndarray: Двумерная булева маска, где True означает известный (неповрежденный) пиксель. """ # Суммируем значения каналов, чтобы получить одноканальное изображение gray_image = damaged_image.sum(axis=2) # Определяем порог в зависимости от нормализации # Если значения в диапазоне [0, 1], порог тоже должен быть в этом диапазоне is_normalized = damaged_image.max() <= 1.0 effective_threshold = threshold / 255.0 if is_normalized else threshold # True там, где пиксель не черный return gray_image > effective_threshold
[документация] def run_inpainting_pipeline( damaged_image_source: Union[str, IO[bytes]], output_dir: str, max_iters: int = 250, original_image_source: Optional[Union[str, IO[bytes]]] = None, use_gpu: bool = True, ) -> Optional[np.ndarray]: """ Главная функция для запуска всего процесса восстановления изображения (inpainting). Оркестрирует загрузку данных, выбор вычислительного бэкенда (CPU/GPU), запуск алгоритма ADMM для матричного дозаполнения и сохранение результатов. Может принимать на вход как пути к файлам, так и байтовые потоки изображений. Args: damaged_image_source (Union[str, IO[bytes]]): Источник поврежденного изображения. Может быть строкой с путем к файлу или файлоподобным объектом с байтами (например, io.BytesIO из Telegram). output_dir (str): Путь к папке, в которую будут сохранены все результаты (восстановленное изображение, графики). Папка будет создана, если не существует. max_iters (int, optional): Максимальное количество итераций для алгоритма ADMM. По умолчанию 250. original_image_source (Optional[Union[str, IO[bytes]]], optional): Источник оригинального (неповрежденного) изображения для создания сравнительной визуализации. Если не указан (None), сравнительный график создаваться не будет. По умолчанию None. use_gpu (bool, optional): Флаг для использования GPU. Если True, будет предпринята попытка использовать CuPy для вычислений. При его отсутствии или ошибке произойдет автоматический откат к NumPy (CPU). По умолчанию True. Returns: Optional[np.ndarray]: Восстановленное изображение в виде NumPy массива (на CPU) в случае успешного выполнения. Возвращает None, если в процессе произошла ошибка. """ # --- 1. Выбор бэкенда и настройка путей --- backend = utils.get_backend(use_gpu) os.makedirs(output_dir, exist_ok=True) # Формируем пути для сохранения результатов RECOVERED_IMAGE_PATH = os.path.join(output_dir, "recovered.png") COMPARISON_PLOT_PATH = os.path.join(output_dir, "comparison.png") CONVERGENCE_PLOT_PATH = os.path.join(output_dir, "convergence.png") # Параметры ADMM ADMM_TOL = 1e-3 ADMM_RHO = 0.5 try: # --- 2. Загрузка данных (всегда на CPU) --- print("Загрузка поврежденного изображения...") damaged_image_np = utils.load_image(damaged_image_source) # --- 3. Создание маски и перенос на выбранный бэкенд --- print("Создание маски на основе поврежденного изображения...") mask_np = create_mask_from_damaged(damaged_image_np) print(f"Передача данных на выбранный бэкенд ({'GPU' if backend is not np else 'CPU'})...") image_to_recover = utils.as_backend_array(damaged_image_np, backend) mask = utils.as_backend_array(mask_np, backend) # --- 4. Обработка каждого канала --- recovered_channels = [] norm_histories = [] channel_names = ["Красный", "Зелёный", "Синий"] for i in range(len(channel_names)): print(f"\n--- Восстановление канала: {channel_names[i]} ---") channel_data = image_to_recover[:, :, i] recovered_channel, history = admm_core.MC_ADMM( Y=channel_data, mask=mask, tol=ADMM_TOL, max_iters=max_iters, r=ADMM_RHO, backend=backend, ) recovered_channels.append(recovered_channel) norm_histories.append(history) # --- 5. Сборка результата и возврат на CPU --- print("\nСборка восстановленного изображения...") recovered_image = backend.stack(recovered_channels, axis=2) # --- 6. Сохранение и визуализация --- print("Сохранение результатов...") # Конвертируем финальный результат в NumPy для сохранения и визуализации recovered_image_np = utils.as_numpy(recovered_image) utils.save_image(recovered_image_np, RECOVERED_IMAGE_PATH) visualize.save_convergence_plot( norm_histories, channel_names, ["red", "green", "blue"], CONVERGENCE_PLOT_PATH, ) if original_image_source: print("Загрузка оригинального изображения для сравнения...") original_image_np = utils.load_image(original_image_source) visualize.save_results_comparison( original_image_np, damaged_image_np, recovered_image_np, COMPARISON_PLOT_PATH, ) print(f"\nВосстановление завершено. Результаты сохранены в папке: {output_dir}") return recovered_image_np except FileNotFoundError as e: print(f"Критическая ошибка: {e}") except Exception as e: print(f"Произошла непредвиденная ошибка: {e}") return None
# if __name__ == '__main__': # Вызов для тестирования # print("--- Запуск в тестовом режиме ---") # # script_dir = os.path.dirname(os.path.abspath(__file__)) # damaged_path = os.path.join(script_dir, 'data', 'damaged.png') # original_path = os.path.join(script_dir, 'data', 'img.png') # output_path = os.path.join(script_dir, 'output_test') # # run_inpainting_pipeline( # damaged_image_path=damaged_path, # output_dir=output_path, # max_iters=250, # original_image_path=original_path, # use_gpu=True # Попытаться использовать GPU, если доступен # )