Files
XplorePlane/XP.ImageProcessing.Processors/图像增强/SuperResolutionProcessor.cs
T
2026-04-14 17:12:31 +08:00

319 lines
12 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// ============================================================================
// Copyright © 2026 Hexagon Technology Center GmbH. All Rights Reserved.
// 文件名: SuperResolutionProcessor.cs
// 描述: 基于深度学习的超分辨率算子
// 功能:
// - 支持 EDSR 和 FSRCNN 超分辨率模型(ONNX 格式)
// - 支持 2x、3x、4x 放大倍率
// - 灰度图像自动转换为三通道输入,推理后转回灰度
// - 模型文件自动搜索,支持自定义路径
// - 使用 Microsoft.ML.OnnxRuntime 进行推理
// 算法: EDSR (Enhanced Deep Residual SR) / FSRCNN (Fast SR CNN)
// 作者: 李伟 wei.lw.li@hexagon.com
// ============================================================================
using Emgu.CV;
using Emgu.CV.CvEnum;
using Emgu.CV.Structure;
using XP.ImageProcessing.Core;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using Serilog;
using System.IO;
namespace XP.ImageProcessing.Processors;
/// <summary>
/// Deep-learning super-resolution processor (EDSR / FSRCNN) that runs ONNX models through ONNX Runtime.
/// </summary>
/// <remarks>
/// The most recently loaded <see cref="InferenceSession"/> is kept in static fields and is therefore
/// shared by every instance of this processor. The cache is not synchronized.
/// NOTE(review): if <see cref="Process"/> can be called concurrently with different models, the cache
/// swap needs external locking — confirm against the callers.
/// </remarks>
public class SuperResolutionProcessor : ImageProcessorBase
{
    /// <summary>Edge length (in source pixels) of the tiles used for large EDSR inputs.</summary>
    private const int EdsrTileSize = 256;

    /// <summary>Number of source pixels shared by neighbouring tiles to soften seams.</summary>
    private const int TileOverlapPixels = 8;

    private static readonly ILogger _logger = Log.ForContext<SuperResolutionProcessor>();

    // Session cache: avoids reloading the model on every call.
    private static InferenceSession? _cachedSession;
    private static string _cachedModelKey = string.Empty;

    /// <summary>
    /// Creates the processor and assigns its localized name and description.
    /// </summary>
    public SuperResolutionProcessor()
    {
        Name = LocalizationHelper.GetString("SuperResolutionProcessor_Name");
        Description = LocalizationHelper.GetString("SuperResolutionProcessor_Description");
    }

    /// <summary>
    /// Registers the "Model" (EDSR / FSRCNN) and "Scale" (2 / 3 / 4) parameters.
    /// </summary>
    protected override void InitializeParameters()
    {
        Parameters.Add("Model", new ProcessorParameter(
            "Model",
            LocalizationHelper.GetString("SuperResolutionProcessor_Model"),
            typeof(string),
            "FSRCNN",
            null,
            null,
            LocalizationHelper.GetString("SuperResolutionProcessor_Model_Desc"),
            new string[] { "EDSR", "FSRCNN" }));

        Parameters.Add("Scale", new ProcessorParameter(
            "Scale",
            LocalizationHelper.GetString("SuperResolutionProcessor_Scale"),
            typeof(string),
            "2",
            null,
            null,
            LocalizationHelper.GetString("SuperResolutionProcessor_Scale_Desc"),
            new string[] { "2", "3", "4" }));

        _logger.Debug("InitializeParameters");
    }

    /// <summary>
    /// Upscales <paramref name="inputImage"/> with the configured model and scale factor.
    /// </summary>
    /// <param name="inputImage">Grayscale image to upscale.</param>
    /// <returns>A new grayscale image produced by the model.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="inputImage"/> is null.</exception>
    /// <exception cref="FileNotFoundException">No matching .onnx model file was found.</exception>
    public override Image<Gray, byte> Process(Image<Gray, byte> inputImage)
    {
        if (inputImage is null)
        {
            throw new ArgumentNullException(nameof(inputImage));
        }

        string model = GetParameter<string>("Model");
        // The parameter is a machine-readable digit string, so parse it culture-independently.
        int scale = int.Parse(GetParameter<string>("Scale"), System.Globalization.CultureInfo.InvariantCulture);

        // Locate the model file.
        string modelPath = FindModelFile(model, scale);
        if (string.IsNullOrEmpty(modelPath))
        {
            _logger.Error("Model file not found: {Model}_x{Scale}.onnx", model, scale);
            throw new FileNotFoundException(
                $"超分辨率模型文件未找到: {model}_x{scale}.onnx\n" +
                $"请将模型文件放置到以下任一目录:\n" +
                $" 1. 程序目录/Models/\n" +
                $" 2. 程序目录/\n" +
                $"模型需要 ONNX 格式。\n" +
                $"可使用 tf2onnx 从 .pb 转换:\n" +
                $" pip install tf2onnx\n" +
                $" python -m tf2onnx.convert --input {model}_x{scale}.pb --output {model}_x{scale}.onnx --inputs input:0 --outputs output:0");
        }

        InferenceSession session = GetOrCreateSession($"{model}_{scale}", modelPath);

        int h = inputImage.Height;
        int w = inputImage.Width;
        _logger.Information("Input image size: {W}x{H}, Model: {Model}, Scale: {Scale}", w, h, model, scale);

        // Large EDSR inputs are processed tile by tile to avoid very slow / out-of-memory single runs.
        bool useTiling = model.StartsWith("EDSR", StringComparison.OrdinalIgnoreCase)
            && (h > EdsrTileSize || w > EdsrTileSize);

        return useTiling
            ? ProcessTiled(session, inputImage, scale, EdsrTileSize)
            : ProcessSingle(session, inputImage, scale);
    }

    /// <summary>
    /// Returns the cached session when it matches <paramref name="modelKey"/>; otherwise disposes the
    /// cached one and loads the model at <paramref name="modelPath"/>.
    /// </summary>
    /// <param name="modelKey">Cache key identifying the model/scale combination.</param>
    /// <param name="modelPath">Path of the .onnx file to load on a cache miss.</param>
    /// <returns>The session to run inference with; it stays owned by the static cache.</returns>
    private static InferenceSession GetOrCreateSession(string modelKey, string modelPath)
    {
        if (_cachedSession is not null && _cachedModelKey == modelKey)
        {
            _logger.Debug("Reusing cached session: {ModelKey}", modelKey);
            return _cachedSession;
        }

        // Invalidate the cache BEFORE loading: if the load below throws, the cache must not keep
        // pointing at the session that has just been disposed.
        _cachedSession?.Dispose();
        _cachedSession = null;
        _cachedModelKey = string.Empty;

        // The options object is only needed while the session is being constructed.
        using var options = new SessionOptions
        {
            GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL
        };

        string provider = "CPU";
        try
        {
            options.AppendExecutionProvider_CUDA(0);
            provider = "CUDA";
            _logger.Information("Using CUDA GPU for inference");
        }
        catch (Exception ex)
        {
            // Best effort: keep the CPU fallback, but record why CUDA could not be enabled.
            _logger.Warning(ex, "CUDA not available, falling back to CPU");
        }

        var session = new InferenceSession(modelPath, options);
        _cachedSession = session;
        _cachedModelKey = modelKey;

        // Log the execution provider that was requested for this session.
        _logger.Information("Loaded ONNX model: {ModelPath}, Providers: {Providers}", modelPath, provider);
        return session;
    }

    /// <summary>
    /// Runs a single inference pass over the whole image (small images, or FSRCNN).
    /// </summary>
    /// <param name="session">Session of the loaded model.</param>
    /// <param name="inputImage">Grayscale image to upscale.</param>
    /// <param name="scale">Scale factor; only used for logging, the output size comes from the model.</param>
    /// <returns>A new grayscale image sized according to the model output tensor.</returns>
    /// <exception cref="InvalidOperationException">
    /// The model output is not a rank-4 tensor with 1 or at least 3 channels.
    /// </exception>
    private Image<Gray, byte> ProcessSingle(InferenceSession session, Image<Gray, byte> inputImage, int scale)
    {
        int h = inputImage.Height;
        int w = inputImage.Width;

        // The input is fed as [1, H, W, C] (NHWC); the last metadata dimension is the channel count.
        string inputName = session.InputMetadata.Keys.First();
        int[] dims = session.InputMetadata[inputName].Dimensions;

        // 1 => single-channel input (FSRCNN); anything else is fed as 3 channels (EDSR).
        int channels = dims[^1] == 1 ? 1 : 3;

        // Fill the flat buffer directly. For the 3-channel case the gray value is replicated into
        // every channel, which is exactly what a Gray -> BGR conversion produces.
        var inputBuffer = new float[h * w * channels];
        byte[,,] source = inputImage.Data;
        Parallel.For(0, h, y =>
        {
            int offset = y * w * channels;
            for (int x = 0; x < w; x++)
            {
                float value = source[y, x, 0];
                for (int c = 0; c < channels; c++)
                {
                    inputBuffer[offset++] = value;
                }
            }
        });
        var inputTensor = new DenseTensor<float>(inputBuffer, new[] { 1, h, w, channels });

        // Inference.
        var inputs = new List<NamedOnnxValue>
        {
            NamedOnnxValue.CreateFromTensor(inputName, inputTensor)
        };
        using var results = session.Run(inputs);
        Tensor<float> outputTensor = results.First().AsTensor<float>();

        // Expected output layout: [1, C, H*scale, W*scale] (NCHW).
        if (outputTensor.Rank != 4)
        {
            throw new InvalidOperationException(
                $"Unexpected model output shape [{string.Join("x", outputTensor.Dimensions.ToArray())}]; expected a rank-4 NCHW tensor.");
        }

        var shape = outputTensor.Dimensions;
        int outC = shape[1];
        int outH = shape[2];
        int outW = shape[3];
        if (outC != 1 && outC < 3)
        {
            throw new InvalidOperationException(
                $"Unexpected model output channel count {outC}; expected 1 or at least 3.");
        }

        // Read the output through its backing buffer and strides instead of the params-array
        // indexer, which would allocate an index array for every pixel access.
        DenseTensor<float> dense = outputTensor as DenseTensor<float> ?? outputTensor.ToDenseTensor();
        var strides = dense.Strides;
        int strideC = strides[1];
        int strideY = strides[2];
        int strideX = strides[3];
        Memory<float> outputMemory = dense.Buffer;
        bool singleChannel = outC == 1;

        var result = new Image<Gray, byte>(outW, outH);
        byte[,,] target = result.Data;
        Parallel.For(0, outH, y =>
        {
            ReadOnlySpan<float> values = outputMemory.Span;
            int rowBase = y * strideY;
            for (int x = 0; x < outW; x++)
            {
                int i = rowBase + x * strideX;
                float gray;
                if (singleChannel)
                {
                    gray = values[i];
                }
                else
                {
                    // Channels are B, G, R. BT.601 luma: 0.299*R + 0.587*G + 0.114*B.
                    float b = values[i];
                    float g = values[i + strideC];
                    float r = values[i + 2 * strideC];
                    gray = 0.299f * r + 0.587f * g + 0.114f * b;
                }

                target[y, x, 0] = ToByte(gray);
            }
        });

        _logger.Debug("ProcessSingle: Scale={Scale}, Output={W}x{H}", scale, outW, outH);
        return result;
    }

    /// <summary>
    /// Converts a model output value to a pixel: clamps to [0, 255] and rounds to the nearest integer
    /// so floating-point error (e.g. 199.99998) does not darken the pixel by one level. NaN maps to 0.
    /// </summary>
    private static byte ToByte(float value)
    {
        if (float.IsNaN(value))
        {
            return 0;
        }

        return (byte)MathF.Round(Math.Clamp(value, 0f, 255f));
    }

    /// <summary>
    /// Tiled inference (large EDSR inputs): cuts the image into overlapping tiles, upscales each
    /// tile separately and writes the results into one output image.
    /// </summary>
    /// <param name="session">Session of the loaded model.</param>
    /// <param name="inputImage">Grayscale image to upscale; its ROI is reset to empty afterwards.</param>
    /// <param name="scale">Scale factor used to size and position tiles in the output.</param>
    /// <param name="tileSize">Tile edge length in source pixels; must exceed the tile overlap.</param>
    /// <returns>A new grayscale image of size (width * scale) x (height * scale).</returns>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="tileSize"/> is not larger than the tile overlap.
    /// </exception>
    private Image<Gray, byte> ProcessTiled(InferenceSession session, Image<Gray, byte> inputImage, int scale, int tileSize)
    {
        int stride = tileSize - TileOverlapPixels;
        if (stride <= 0)
        {
            throw new ArgumentOutOfRangeException(nameof(tileSize), tileSize, "Tile size must be larger than the tile overlap.");
        }

        int h = inputImage.Height;
        int w = inputImage.Width;
        int tilesX = CountTiles(w, tileSize, stride);
        int tilesY = CountTiles(h, tileSize, stride);
        _logger.Information("Tiled processing: {TilesX}x{TilesY} tiles, tileSize={TileSize}", tilesX, tilesY, tileSize);

        var result = new Image<Gray, byte>(w * scale, h * scale);
        try
        {
            for (int ty = 0; ty < tilesY; ty++)
            {
                for (int tx = 0; tx < tilesX; tx++)
                {
                    // Clamp the tile origin so the last tile ends exactly at the image border.
                    int srcX = Math.Max(Math.Min(tx * stride, w - tileSize), 0);
                    int srcY = Math.Max(Math.Min(ty * stride, h - tileSize), 0);
                    int tw = Math.Min(tileSize, w - srcX);
                    int th = Math.Min(tileSize, h - srcY);

                    // Crop and upscale one tile; both temporaries are released even on failure.
                    using var tile = CropTile(inputImage, new System.Drawing.Rectangle(srcX, srcY, tw, th));
                    using var srTile = ProcessSingle(session, tile, scale);

                    // Write the tile into the output; later tiles overwrite the overlap region.
                    result.ROI = new System.Drawing.Rectangle(srcX * scale, srcY * scale, srTile.Width, srTile.Height);
                    try
                    {
                        srTile.CopyTo(result);
                    }
                    finally
                    {
                        result.ROI = System.Drawing.Rectangle.Empty;
                    }
                }
            }
        }
        catch
        {
            // Do not leak the partially filled output when a tile fails.
            result.Dispose();
            throw;
        }

        _logger.Debug("ProcessTiled: Scale={Scale}, Output={W}x{H}", scale, result.Width, result.Height);
        return result;
    }

    /// <summary>
    /// Number of tiles needed to cover <paramref name="length"/> pixels when tile origins advance by
    /// <paramref name="stride"/> and the last origin is clamped to <c>length - tileSize</c>.
    /// Counting this way avoids inferring the clamped last tile twice.
    /// </summary>
    private static int CountTiles(int length, int tileSize, int stride)
    {
        if (length <= tileSize)
        {
            return 1;
        }

        // ceil((length - tileSize) / stride) further tiles after the first one.
        return (length - tileSize + stride - 1) / stride + 1;
    }

    /// <summary>
    /// Copies <paramref name="region"/> out of <paramref name="source"/> by temporarily setting its ROI;
    /// the ROI is reset to empty even when the copy throws.
    /// </summary>
    private static Image<Gray, byte> CropTile(Image<Gray, byte> source, System.Drawing.Rectangle region)
    {
        source.ROI = region;
        try
        {
            return source.Copy();
        }
        finally
        {
            source.ROI = System.Drawing.Rectangle.Empty;
        }
    }

    /// <summary>
    /// Looks for "{model}_x{scale}.onnx" in, in order: the application base directory's "Models"
    /// folder, the base directory itself, the current directory's "Models" folder, and the current
    /// directory.
    /// </summary>
    /// <param name="model">Model name used as the file-name prefix.</param>
    /// <param name="scale">Scale factor used in the file name.</param>
    /// <returns>The first existing path, or an empty string when no candidate exists.</returns>
    private static string FindModelFile(string model, int scale)
    {
        string fileName = $"{model}_x{scale}.onnx";
        string baseDir = AppDomain.CurrentDomain.BaseDirectory;
        string currentDir = Directory.GetCurrentDirectory();

        string[] candidates =
        {
            Path.Combine(baseDir, "Models", fileName),
            Path.Combine(baseDir, fileName),
            Path.Combine(currentDir, "Models", fileName),
            Path.Combine(currentDir, fileName),
        };

        string? match = candidates.FirstOrDefault(File.Exists);
        if (match is null)
        {
            _logger.Warning("Model file not found: {Model}_x{Scale}.onnx", model, scale);
            return string.Empty;
        }

        _logger.Debug("Found model file: {Path}", match);
        return match;
    }
}