28a1ae0898

321 lines
12 KiB
C#
321 lines
12 KiB
C#
// ============================================================================
// Copyright © 2026 Hexagon Technology Center GmbH. All Rights Reserved.
// File: SuperResolutionProcessor.cs
// Description: Deep-learning based super-resolution processor
// Features:
//   - Supports EDSR and FSRCNN super-resolution models (ONNX format)
//   - Supports 2x, 3x and 4x upscaling factors
//   - Grayscale images are automatically converted to three-channel input and converted back to grayscale after inference
//   - Model files are located automatically; custom paths are supported
//   - Inference is performed with Microsoft.ML.OnnxRuntime
// Algorithm: EDSR (Enhanced Deep Residual SR) / FSRCNN (Fast SR CNN)
// Author: 李伟 wei.lw.li@hexagon.com
// ============================================================================
|
||
|
||
using Emgu.CV;
|
||
using Emgu.CV.CvEnum;
|
||
using Emgu.CV.Structure;
|
||
using XP.ImageProcessing.Core;
|
||
using Microsoft.ML.OnnxRuntime;
|
||
using Microsoft.ML.OnnxRuntime.Tensors;
|
||
using Serilog;
|
||
using System.IO;
|
||
|
||
namespace XP.ImageProcessing.Processors;
|
||
|
||
/// <summary>
/// Deep-learning super-resolution processor (EDSR / FSRCNN) that runs ONNX models through ONNX Runtime.
/// </summary>
/// <remarks>
/// A single <see cref="InferenceSession"/> is cached in static state and shared by every instance.
/// NOTE(review): the cached session is handed out under <c>_sessionLock</c> but inference runs outside
/// the lock, so a concurrent call that selects a different model can dispose a session that is still
/// in use. Confirm whether <see cref="Process"/> is ever invoked concurrently with different models.
/// </remarks>
public class SuperResolutionProcessor : ImageProcessorBase
{
    private static readonly ILogger _logger = Log.ForContext<SuperResolutionProcessor>();

    // Single-entry session cache so the model is not reloaded on every call.
    private static InferenceSession? _cachedSession;

    // Key ("{model}_{scale}") of the model currently held in _cachedSession.
    private static string _cachedModelKey = string.Empty;

    // Guards _cachedSession and _cachedModelKey.
    private static readonly object _sessionLock = new();

    /// <summary>
    /// Creates the processor and sets its localized display name and description.
    /// </summary>
    public SuperResolutionProcessor()
    {
        Name = LocalizationHelper.GetString("SuperResolutionProcessor_Name");
        Description = LocalizationHelper.GetString("SuperResolutionProcessor_Description");
    }

    /// <summary>
    /// Registers the "Model" (EDSR / FSRCNN, default FSRCNN) and "Scale" (2 / 3 / 4, default 2) parameters.
    /// </summary>
    protected override void InitializeParameters()
    {
        // The two null arguments are presumably the numeric min/max bounds, which do not apply
        // to option-list parameters — TODO confirm against ProcessorParameter.
        Parameters.Add("Model", new ProcessorParameter(
            "Model",
            LocalizationHelper.GetString("SuperResolutionProcessor_Model"),
            typeof(string),
            "FSRCNN",
            null,
            null,
            LocalizationHelper.GetString("SuperResolutionProcessor_Model_Desc"),
            new string[] { "EDSR", "FSRCNN" }));

        Parameters.Add("Scale", new ProcessorParameter(
            "Scale",
            LocalizationHelper.GetString("SuperResolutionProcessor_Scale"),
            typeof(string),
            "2",
            null,
            null,
            LocalizationHelper.GetString("SuperResolutionProcessor_Scale_Desc"),
            new string[] { "2", "3", "4" }));

        _logger.Debug("InitializeParameters");
    }

    /// <summary>
    /// Upscales a grayscale image with the configured model and scale factor.
    /// </summary>
    /// <param name="inputImage">Grayscale image to upscale.</param>
    /// <returns>A new grayscale image produced by the model.</returns>
    /// <exception cref="FileNotFoundException">
    /// No <c>{model}_x{scale}.onnx</c> file was found in any of the searched directories.
    /// </exception>
    public override Image<Gray, byte> Process(Image<Gray, byte> inputImage)
    {
        string model = GetParameter<string>("Model");

        // NOTE(review): "Scale" is registered with typeof(string) and default "2" but is read as int here.
        // This relies on GetParameter<int> converting the stored value — confirm in ImageProcessorBase.
        int scale = GetParameter<int>("Scale");

        // Locate the model file.
        string modelPath = FindModelFile(model, scale);
        if (string.IsNullOrEmpty(modelPath))
        {
            _logger.Error("Model file not found: {Model}_x{Scale}.onnx", model, scale);
            throw new FileNotFoundException(
                $"超分辨率模型文件未找到: {model}_x{scale}.onnx\n" +
                $"请将模型文件放置到以下任一目录:\n" +
                $"  1. 程序目录/Models/\n" +
                $"  2. 程序目录/\n" +
                $"模型需要 ONNX 格式。\n" +
                $"可使用 tf2onnx 从 .pb 转换:\n" +
                $"  pip install tf2onnx\n" +
                $"  python -m tf2onnx.convert --input {model}_x{scale}.pb --output {model}_x{scale}.onnx --inputs input:0 --outputs output:0");
        }

        // Load the session, or reuse the cached one.
        InferenceSession session = GetOrCreateSession(model, scale, modelPath);

        int h = inputImage.Height;
        int w = inputImage.Width;
        _logger.Information("Input image size: {W}x{H}, Model: {Model}, Scale: {Scale}", w, h, model, scale);

        // Large EDSR inputs are processed tile by tile to avoid very slow single runs / out-of-memory.
        const int TileSize = 256;
        bool useTiling = model.StartsWith("EDSR", StringComparison.OrdinalIgnoreCase) && (h > TileSize || w > TileSize);

        return useTiling
            ? ProcessTiled(session, inputImage, scale, TileSize)
            : ProcessSingle(session, inputImage, scale);
    }

    /// <summary>
    /// Returns the cached session when it matches the requested model, otherwise loads a new one
    /// (trying the CUDA execution provider first and falling back to CPU) and caches it.
    /// </summary>
    private static InferenceSession GetOrCreateSession(string model, int scale, string modelPath)
    {
        string modelKey = $"{model}_{scale}";

        lock (_sessionLock)
        {
            if (_cachedModelKey == modelKey && _cachedSession != null)
            {
                _logger.Debug("Reusing cached session: {ModelKey}", modelKey);
                return _cachedSession;
            }

            // Release the previous model before loading the next one, and clear the cache entry
            // right away: if loading below throws, the cache must not keep pointing at a disposed session.
            _cachedSession?.Dispose();
            _cachedSession = null;
            _cachedModelKey = string.Empty;

            // The options object is only needed while the session is being constructed.
            using var options = new SessionOptions();
            options.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL;

            bool cudaEnabled = false;
            try
            {
                options.AppendExecutionProvider_CUDA(0);
                cudaEnabled = true;
            }
            catch (Exception ex)
            {
                _logger.Warning(ex, "CUDA EP unavailable (check CUDA/cuDNN version match), falling back to CPU");
            }

            var session = new InferenceSession(modelPath, options);
            _cachedSession = session;
            _cachedModelKey = modelKey;
            _logger.Information("Loaded ONNX model: {ModelPath}, CUDA={CudaEnabled}", modelPath, cudaEnabled);
            return session;
        }
    }

    /// <summary>
    /// Runs one inference over the whole image (small images, or FSRCNN).
    /// </summary>
    /// <param name="session">Loaded model session.</param>
    /// <param name="inputImage">Grayscale image (or tile) to upscale.</param>
    /// <param name="scale">Scale factor; only used for logging here, the output size comes from the model.</param>
    /// <returns>A new grayscale image sized according to the model output.</returns>
    /// <exception cref="InvalidOperationException">The model output is not a 4-D tensor with 1 or at least 3 channels.</exception>
    private Image<Gray, byte> ProcessSingle(InferenceSession session, Image<Gray, byte> inputImage, int scale)
    {
        int h = inputImage.Height;
        int w = inputImage.Width;

        // The model input is assumed to be [1, H, W, C] (NHWC), so the last dimension is the
        // channel count (1 or 3).
        string inputName = session.InputMetadata.Keys.First();
        int[] dims = session.InputMetadata[inputName].Dimensions;
        int inputChannels = dims[^1];

        DenseTensor<float> inputTensor = inputChannels == 1
            ? BuildGrayTensor(inputImage, h, w)
            : BuildBgrTensor(inputImage, h, w);

        // Inference.
        var inputs = new List<NamedOnnxValue>
        {
            NamedOnnxValue.CreateFromTensor(inputName, inputTensor)
        };

        using var results = session.Run(inputs);
        var outputTensor = (DenseTensor<float>)results.First().AsTensor<float>();

        // The output is assumed to be [1, C, H*scale, W*scale] (NCHW).
        var shape = outputTensor.Dimensions;
        if (shape.Length != 4)
        {
            throw new InvalidOperationException(
                $"Unexpected model output rank {shape.Length}; expected a 4-D [1, C, H, W] tensor.");
        }

        int outC = shape[1];
        int outH = shape[2];
        int outW = shape[3];
        if (outC != 1 && outC < 3)
        {
            throw new InvalidOperationException(
                $"Unexpected model output channel count {outC}; expected 1 or 3.");
        }

        // A Span cannot be captured by the lambdas below, so copy the data into an array.
        float[] outBuf = outputTensor.Buffer.ToArray();

        Image<Gray, byte> result = new(outW, outH);
        var outData = result.Data;
        int planeSize = outH * outW;

        if (outC == 1)
        {
            // FSRCNN: [1, 1, outH, outW]
            Parallel.For(0, outH, y =>
            {
                int rowOffset = y * outW;
                for (int x = 0; x < outW; x++)
                {
                    outData[y, x, 0] = (byte)Math.Clamp((int)outBuf[rowOffset + x], 0, 255);
                }
            });
        }
        else
        {
            // EDSR: [1, 3, outH, outW]; planes are read as B, G, R and folded to gray with BT.601 weights.
            Parallel.For(0, outH, y =>
            {
                int rowOffset = y * outW;
                for (int x = 0; x < outW; x++)
                {
                    int i = rowOffset + x;
                    float b = outBuf[i];
                    float g = outBuf[planeSize + i];
                    float r = outBuf[planeSize * 2 + i];
                    outData[y, x, 0] = (byte)Math.Clamp((int)(0.299f * r + 0.587f * g + 0.114f * b), 0, 255);
                }
            });
        }

        _logger.Debug("ProcessSingle: Scale={Scale}, Output={W}x{H}", scale, outW, outH);

        return result;
    }

    /// <summary>
    /// Builds a [1, H, W, 1] float tensor from the grayscale pixel values (single-channel models such as FSRCNN).
    /// </summary>
    private static DenseTensor<float> BuildGrayTensor(Image<Gray, byte> image, int h, int w)
    {
        // Fill a flat array row-parallel instead of indexing the tensor element by element.
        var buf = new float[h * w];
        var pixels = image.Data;
        Parallel.For(0, h, y =>
        {
            int rowOffset = y * w;
            for (int x = 0; x < w; x++)
            {
                buf[rowOffset + x] = pixels[y, x, 0];
            }
        });

        return new DenseTensor<float>(buf, new[] { 1, h, w, 1 });
    }

    /// <summary>
    /// Converts the grayscale image to BGR and builds a [1, H, W, 3] float tensor (three-channel models such as EDSR).
    /// </summary>
    private static DenseTensor<float> BuildBgrTensor(Image<Gray, byte> image, int h, int w)
    {
        using var colorInput = new Image<Bgr, byte>(w, h);
        CvInvoke.CvtColor(image, colorInput, ColorConversion.Gray2Bgr);

        var buf = new float[h * w * 3];
        var pixels = colorInput.Data;
        Parallel.For(0, h, y =>
        {
            int rowOffset = y * w * 3;
            for (int x = 0; x < w; x++)
            {
                int px = rowOffset + x * 3;
                buf[px] = pixels[y, x, 0];
                buf[px + 1] = pixels[y, x, 1];
                buf[px + 2] = pixels[y, x, 2];
            }
        });

        return new DenseTensor<float>(buf, new[] { 1, h, w, 3 });
    }

    /// <summary>
    /// Tiled inference (large EDSR inputs): the image is cut into tiles that are upscaled one by one
    /// and written into the output image.
    /// </summary>
    /// <param name="session">Loaded model session.</param>
    /// <param name="inputImage">Grayscale image to upscale; its ROI is temporarily changed and reset to empty.</param>
    /// <param name="scale">Scale factor used to size the output and place each upscaled tile.</param>
    /// <param name="tileSize">Maximum tile edge length in input pixels.</param>
    /// <returns>A new grayscale image of size (width * scale) x (height * scale).</returns>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="tileSize"/> is not larger than the tile overlap.</exception>
    private Image<Gray, byte> ProcessTiled(InferenceSession session, Image<Gray, byte> inputImage, int scale, int tileSize)
    {
        // Neighbouring tiles overlap by this many input pixels; tiles written later overwrite the overlap.
        const int Overlap = 8;
        if (tileSize <= Overlap)
        {
            throw new ArgumentOutOfRangeException(nameof(tileSize), tileSize, "Tile size must be larger than the tile overlap.");
        }

        int h = inputImage.Height;
        int w = inputImage.Width;
        int stride = tileSize - Overlap;

        // Distinct tile origins per axis; the last tile is shifted back so it ends on the image border.
        List<int> originsX = ComputeTileOrigins(w, tileSize, stride);
        List<int> originsY = ComputeTileOrigins(h, tileSize, stride);
        _logger.Information("Tiled processing: {TilesX}x{TilesY} tiles, tileSize={TileSize}", originsX.Count, originsY.Count, tileSize);

        var result = new Image<Gray, byte>(w * scale, h * scale);
        try
        {
            foreach (int srcY in originsY)
            {
                foreach (int srcX in originsX)
                {
                    int tw = Math.Min(tileSize, w - srcX);
                    int th = Math.Min(tileSize, h - srcY);

                    // Crop the tile, upscale it, and paste it at the scaled position.
                    using var tile = CopyRegion(inputImage, new System.Drawing.Rectangle(srcX, srcY, tw, th));
                    using var srTile = ProcessSingle(session, tile, scale);

                    result.ROI = new System.Drawing.Rectangle(srcX * scale, srcY * scale, srTile.Width, srTile.Height);
                    try
                    {
                        srTile.CopyTo(result);
                    }
                    finally
                    {
                        result.ROI = System.Drawing.Rectangle.Empty;
                    }
                }
            }
        }
        catch
        {
            // Do not leak the partially filled output when a tile fails.
            result.Dispose();
            throw;
        }

        _logger.Debug("ProcessTiled: Scale={Scale}, Output={W}x{H}", scale, result.Width, result.Height);
        return result;
    }

    /// <summary>
    /// Returns the ascending, duplicate-free tile start offsets that cover <paramref name="length"/> pixels
    /// with tiles of <paramref name="tileSize"/> placed every <paramref name="stride"/> pixels.
    /// </summary>
    private static List<int> ComputeTileOrigins(int length, int tileSize, int stride)
    {
        var origins = new List<int>();

        // One tile is enough when the whole axis fits into it.
        if (length <= tileSize)
        {
            origins.Add(0);
            return origins;
        }

        int last = length - tileSize;
        for (int pos = 0; pos < last; pos += stride)
        {
            origins.Add(pos);
        }

        // Final tile, aligned to the end of the axis.
        origins.Add(last);
        return origins;
    }

    /// <summary>
    /// Copies a rectangular region out of <paramref name="image"/> by temporarily setting its ROI;
    /// the ROI is reset to empty afterwards, even when the copy fails.
    /// </summary>
    private static Image<Gray, byte> CopyRegion(Image<Gray, byte> image, System.Drawing.Rectangle region)
    {
        image.ROI = region;
        try
        {
            return image.Copy();
        }
        finally
        {
            image.ROI = System.Drawing.Rectangle.Empty;
        }
    }

    /// <summary>
    /// Looks for <c>{model}_x{scale}.onnx</c> in several directories, in priority order.
    /// </summary>
    /// <param name="model">Model name used as the file name prefix.</param>
    /// <param name="scale">Scale factor used in the file name.</param>
    /// <returns>The first existing path, or an empty string when the file is not found.</returns>
    private static string FindModelFile(string model, int scale)
    {
        string baseDir = AppDomain.CurrentDomain.BaseDirectory;
        string currentDir = Directory.GetCurrentDirectory();
        string fileName = $"{model}_x{scale}.onnx";

        // Application directory first, then the current working directory; "Models" subfolder before the root.
        string[] searchPaths =
        {
            Path.Combine(baseDir, "Models", fileName),
            Path.Combine(baseDir, fileName),
            Path.Combine(currentDir, "Models", fileName),
            Path.Combine(currentDir, fileName),
        };

        foreach (var path in searchPaths)
        {
            if (File.Exists(path))
            {
                _logger.Debug("Found model file: {Path}", path);
                return path;
            }
        }

        _logger.Warning("Model file not found: {Model}_x{Scale}.onnx", model, scale);
        return string.Empty;
    }
}
|