// ============================================================================
// Copyright © 2026 Hexagon Technology Center GmbH. All Rights Reserved.
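// File: SuperResolutionProcessor.cs
// Description: Deep-learning-based super-resolution processor
// Features:
// - Supports EDSR and FSRCNN super-resolution models (ONNX format)
// - Supports 2x, 3x and 4x upscaling factors
// - Grayscale images are automatically converted to three-channel input and converted back to grayscale after inference
// - Model files are located automatically; custom paths are supported
// - Inference runs on Microsoft.ML.OnnxRuntime
// Algorithms: EDSR (Enhanced Deep Residual SR) / FSRCNN (Fast SR CNN)
// Author: Li Wei (李伟) wei.lw.li@hexagon.com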
// ============================================================================
using Emgu.CV;
using Emgu.CV.CvEnum;
using Emgu.CV.Structure;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using Serilog;
using XP.ImageProcessing.Core;
namespace XP.ImageProcessing.Processors;
/// <summary>
/// Deep-learning super-resolution processor (EDSR / FSRCNN) that runs inference through ONNX Runtime.
/// </summary>
public class SuperResolutionProcessor : ImageProcessorBase
{
// Logger used by every method of this processor.
// NOTE(review): `Log.ForContext()` is shown with neither a type argument nor parameters; a generic
// argument appears to have been lost from this line — confirm against the original file.
private static readonly ILogger _logger = Log.ForContext();
// Session cache: Process() reuses the session when the requested "{model}_{scale}" key matches
// _cachedModelKey, so the model is not reloaded on every call.
// NOTE(review): both fields are static and no synchronization is visible in this file —
// confirm Process() is never called concurrently.
private static InferenceSession? _cachedSession;
private static string _cachedModelKey = string.Empty;
/// <summary>
/// Creates the processor and fills in its display name and description from the localized resources.
/// </summary>
public SuperResolutionProcessor()
{
    var localizedName = LocalizationHelper.GetString("SuperResolutionProcessor_Name");
    Name = localizedName;

    var localizedDescription = LocalizationHelper.GetString("SuperResolutionProcessor_Description");
    Description = localizedDescription;
}
/// <summary>
/// Registers the two parameters read by Process: "Model" (choices "EDSR" / "FSRCNN", default "FSRCNN")
/// and "Scale" (choices "2" / "3" / "4", default "2"). Both are declared with type string.
/// </summary>
protected override void InitializeParameters()
{
    // The two null arguments are passed through unchanged; their meaning is defined by
    // ProcessorParameter, which is not visible in this file.
    string[] modelChoices = { "EDSR", "FSRCNN" };
    var modelParameter = new ProcessorParameter(
        "Model",
        LocalizationHelper.GetString("SuperResolutionProcessor_Model"),
        typeof(string),
        "FSRCNN",
        null,
        null,
        LocalizationHelper.GetString("SuperResolutionProcessor_Model_Desc"),
        modelChoices);
    Parameters.Add("Model", modelParameter);

    string[] scaleChoices = { "2", "3", "4" };
    var scaleParameter = new ProcessorParameter(
        "Scale",
        LocalizationHelper.GetString("SuperResolutionProcessor_Scale"),
        typeof(string),
        "2",
        null,
        null,
        LocalizationHelper.GetString("SuperResolutionProcessor_Scale_Desc"),
        scaleChoices);
    Parameters.Add("Scale", scaleParameter);

    _logger.Debug("InitializeParameters");
}
/// <summary>
/// Upscales <paramref name="inputImage"/> with the model chosen by the "Model" and "Scale" parameters.
/// The ONNX session is cached in static fields and reused while the model/scale pair stays the same.
/// EDSR inputs larger than 256 pixels in either dimension are processed tile by tile; everything
/// else goes through a single inference pass.
/// </summary>
/// <param name="inputImage">Image to upscale.</param>
/// <returns>The image produced by ProcessTiled or ProcessSingle.</returns>
/// <exception cref="ArgumentNullException"><paramref name="inputImage"/> is null.</exception>
/// <exception cref="FileNotFoundException">No "{model}_x{scale}.onnx" file exists in any searched directory.</exception>
// NOTE(review): generic type arguments (e.g. on Image and GetParameter) appear to have been
// stripped from this file; the tokens below are kept exactly as found — confirm against the original.
public override Image Process(Image inputImage)
{
    if (inputImage is null)
    {
        throw new ArgumentNullException(nameof(inputImage));
    }

    string model = GetParameter("Model");
    int scale = int.Parse(GetParameter("Scale"));

    // Locate the model file.
    string modelPath = FindModelFile(model, scale);
    if (string.IsNullOrEmpty(modelPath))
    {
        _logger.Error("Model file not found: {Model}_x{Scale}.onnx", model, scale);
        // The message text below had mis-encoded characters; they are repaired here.
        throw new FileNotFoundException(
            $"超分辨率模型文件未找到: {model}_x{scale}.onnx\n" +
            $"请将模型文件放置到以下任一目录:\n" +
            $" 1. 程序目录/Models/\n" +
            $" 2. 程序目录/\n" +
            $"模型需为 ONNX 格式。\n" +
            $"可使用 tf2onnx 从 .pb 转换:\n" +
            $" pip install tf2onnx\n" +
            $" python -m tf2onnx.convert --input {model}_x{scale}.pb --output {model}_x{scale}.onnx --inputs input:0 --outputs output:0");
    }

    // Load the session, or reuse the cached one for the same model/scale pair.
    InferenceSession session = GetOrCreateSession(modelPath, $"{model}_{scale}");

    int h = inputImage.Height;
    int w = inputImage.Width;
    _logger.Information("Input image size: {W}x{H}, Model: {Model}, Scale: {Scale}", w, h, model, scale);

    // Large EDSR inputs are tiled so a single inference pass does not become too slow or run out of memory.
    const int TileSize = 256;
    bool isEdsr = model.StartsWith("EDSR", StringComparison.OrdinalIgnoreCase);
    bool exceedsTile = h > TileSize || w > TileSize;
    return isEdsr && exceedsTile
        ? ProcessTiled(session, inputImage, scale, TileSize)
        : ProcessSingle(session, inputImage, scale);
}

/// <summary>
/// Returns the cached session when it was created for <paramref name="modelKey"/>; otherwise
/// disposes the cached one, loads <paramref name="modelPath"/> and caches the new session.
/// </summary>
/// <param name="modelPath">Path of the .onnx file to load on a cache miss.</param>
/// <param name="modelKey">Cache key identifying the model/scale pair.</param>
/// <returns>A session for the requested model.</returns>
private static InferenceSession GetOrCreateSession(string modelPath, string modelKey)
{
    if (_cachedSession != null && _cachedModelKey == modelKey)
    {
        _logger.Debug("Reusing cached session: {ModelKey}", modelKey);
        return _cachedSession;
    }

    // Clear the cache BEFORE disposing: if loading the new model throws, the cache must not
    // keep pointing at a disposed session that a later call would happily reuse.
    InferenceSession? stale = _cachedSession;
    _cachedSession = null;
    _cachedModelKey = string.Empty;
    stale?.Dispose();

    // The options object is only needed while the session is being constructed.
    using var options = new SessionOptions();
    options.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL;

    bool cudaRequested = false;
    try
    {
        options.AppendExecutionProvider_CUDA(0);
        cudaRequested = true;
        _logger.Information("Using CUDA GPU for inference");
    }
    catch (Exception ex)
    {
        // Best effort: keep going on CPU, but record why CUDA could not be enabled.
        _logger.Warning(ex, "CUDA not available, falling back to CPU");
    }

    var session = new InferenceSession(modelPath, options);
    _cachedSession = session;
    _cachedModelKey = modelKey;

    // Log which execution provider was requested for this session.
    _logger.Information("Loaded ONNX model: {ModelPath}, Providers: {Providers}",
        modelPath, cudaRequested ? "CUDA" : "CPU");
    return session;
}
/// <summary>
/// Runs one inference pass over the whole image (used for small images and for FSRCNN).
/// The input tensor is built as [1, H, W, C] where C is the last dimension reported by the
/// model's first input; the first output is read as [1, C, outH, outW].
/// </summary>
/// <param name="session">Loaded ONNX Runtime session; its first input is fed and its first output is read.</param>
/// <param name="inputImage">Source image; channel 0 of its pixel data is used as the gray value.</param>
/// <param name="scale">Upscaling factor; only used for logging here, the output size comes from the model output.</param>
/// <returns>A new image of the model's output size with the result written to channel 0.</returns>
/// <exception cref="InvalidOperationException">The model output is not a rank-4 tensor.</exception>
// NOTE(review): generic type arguments (Image, DenseTensor, List, AsTensor) appear to have been
// stripped from this file; the tokens below are kept exactly as found — confirm against the original.
private Image ProcessSingle(InferenceSession session, Image inputImage, int scale)
{
    int h = inputImage.Height;
    int w = inputImage.Width;

    // Model input description: the last dimension is read as the channel count (NHWC layout).
    string inputName = session.InputMetadata.Keys.First();
    int[] dims = session.InputMetadata[inputName].Dimensions;
    int inputChannels = dims[^1];

    // Fill a flat buffer row by row in parallel, then wrap it in a tensor of shape [1, H, W, C].
    DenseTensor inputTensor;
    if (inputChannels == 1)
    {
        // FSRCNN: single-channel gray input.
        var buf = new float[h * w];
        var imgData = inputImage.Data;
        Parallel.For(0, h, y =>
        {
            int rowOffset = y * w;
            for (int x = 0; x < w; x++)
            {
                buf[rowOffset + x] = imgData[y, x, 0];
            }
        });
        inputTensor = new DenseTensor(buf, new[] { 1, h, w, 1 });
    }
    else
    {
        // EDSR: three-channel BGR input obtained by expanding the gray image.
        using var colorInput = new Image(w, h);
        CvInvoke.CvtColor(inputImage, colorInput, ColorConversion.Gray2Bgr);
        var buf = new float[h * w * 3];
        var imgData = colorInput.Data;
        Parallel.For(0, h, y =>
        {
            int rowOffset = y * w * 3;
            for (int x = 0; x < w; x++)
            {
                int px = rowOffset + x * 3;
                buf[px] = imgData[y, x, 0];
                buf[px + 1] = imgData[y, x, 1];
                buf[px + 2] = imgData[y, x, 2];
            }
        });
        inputTensor = new DenseTensor(buf, new[] { 1, h, w, 3 });
    }

    // Inference.
    var inputs = new List
    {
        NamedOnnxValue.CreateFromTensor(inputName, inputTensor)
    };
    using var results = session.Run(inputs);
    var outputTensor = results.First().AsTensor();

    // Output is read as [1, C, H*scale, W*scale] (NCHW).
    var shape = outputTensor.Dimensions;
    if (shape.Length != 4)
    {
        throw new InvalidOperationException(
            $"Unexpected model output rank {shape.Length}; expected a rank-4 NCHW tensor.");
    }

    int outC = shape[1];
    int outH = shape[2];
    int outW = shape[3];

    // Convert the output tensor to a gray image, one row per parallel iteration.
    Image result = new Image(outW, outH);
    var outData = result.Data;
    if (outC == 1)
    {
        // FSRCNN: single-channel output [1, 1, outH, outW].
        Parallel.For(0, outH, y =>
        {
            for (int x = 0; x < outW; x++)
            {
                outData[y, x, 0] = SaturateToByte(outputTensor[0, 0, y, x]);
            }
        });
    }
    else
    {
        // EDSR: three-channel output [1, 3, outH, outW]; gray is computed directly,
        // without allocating an intermediate BGR image.
        Parallel.For(0, outH, y =>
        {
            for (int x = 0; x < outW; x++)
            {
                float b = outputTensor[0, 0, y, x];
                float g = outputTensor[0, 1, y, x];
                float r = outputTensor[0, 2, y, x];
                // BT.601 luma: 0.299*R + 0.587*G + 0.114*B
                outData[y, x, 0] = SaturateToByte(0.299f * r + 0.587f * g + 0.114f * b);
            }
        });
    }

    _logger.Debug("ProcessSingle: Scale={Scale}, Output={W}x{H}", scale, outW, outH);
    return result;
}

/// <summary>
/// Converts a float sample to a byte: values at or below 0 (and NaN) map to 0, values at or
/// above 255 map to 255, anything in between is truncated toward zero.
/// Clamping happens before the integer conversion because converting an out-of-range or NaN
/// float to an integer type gives an unspecified result.
/// </summary>
private static byte SaturateToByte(float value)
{
    // The negated comparison is also true for NaN.
    if (!(value > 0f))
    {
        return 0;
    }

    if (value >= 255f)
    {
        return 255;
    }

    return (byte)value;
}
/// <summary>
/// Tiled inference (large EDSR inputs): cuts the image into tiles of at most
/// <paramref name="tileSize"/> pixels per side, upscales each with ProcessSingle and copies the
/// results into one output image of size (width * scale) x (height * scale).
/// Neighbouring tiles overlap by 8 source pixels; where outputs overlap, the tile processed
/// later overwrites the earlier one (no blending).
/// </summary>
/// <param name="session">Loaded ONNX Runtime session, passed through to ProcessSingle.</param>
/// <param name="inputImage">Source image. Its ROI is set temporarily while each tile is copied and reset to empty afterwards.</param>
/// <param name="scale">Upscaling factor used to size the output and place each tile.</param>
/// <param name="tileSize">Maximum tile side length in source pixels; must exceed the 8-pixel overlap.</param>
/// <returns>The assembled upscaled image.</returns>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="tileSize"/> is not larger than the overlap.</exception>
// NOTE(review): the generic type arguments of Image appear to have been stripped from this
// file; the tokens below are kept exactly as found — confirm against the original.
private Image ProcessTiled(InferenceSession session, Image inputImage, int scale, int tileSize)
{
    const int Overlap = 8; // overlap in source pixels, meant to reduce seam artifacts

    int stride = tileSize - Overlap;
    if (stride <= 0)
    {
        throw new ArgumentOutOfRangeException(
            nameof(tileSize), tileSize, "Tile size must be larger than the tile overlap.");
    }

    int h = inputImage.Height;
    int w = inputImage.Width;

    // Count only distinct tile origins; the last origin is clamped to (length - tileSize),
    // so counting by length / stride would run the final tile more than once.
    int tilesX = CountTiles(w, tileSize, stride);
    int tilesY = CountTiles(h, tileSize, stride);
    _logger.Information("Tiled processing: {TilesX}x{TilesY} tiles, tileSize={TileSize}", tilesX, tilesY, tileSize);

    var result = new Image(w * scale, h * scale);
    try
    {
        for (int ty = 0; ty < tilesY; ty++)
        {
            for (int tx = 0; tx < tilesX; tx++)
            {
                // Clamp the origin so the last tile ends exactly at the image border.
                int srcX = Math.Max(Math.Min(tx * stride, w - tileSize), 0);
                int srcY = Math.Max(Math.Min(ty * stride, h - tileSize), 0);
                int tw = Math.Min(tileSize, w - srcX);
                int th = Math.Min(tileSize, h - srcY);

                // Crop the tile; always reset the ROI on the caller's image, even if the copy fails.
                Image tile;
                inputImage.ROI = new System.Drawing.Rectangle(srcX, srcY, tw, th);
                try
                {
                    tile = inputImage.Copy();
                }
                finally
                {
                    inputImage.ROI = System.Drawing.Rectangle.Empty;
                }

                using (tile)
                using (var srTile = ProcessSingle(session, tile, scale))
                {
                    // Write the upscaled tile at the scaled origin.
                    int dstX = srcX * scale;
                    int dstY = srcY * scale;
                    result.ROI = new System.Drawing.Rectangle(dstX, dstY, srTile.Width, srTile.Height);
                    try
                    {
                        srTile.CopyTo(result);
                    }
                    finally
                    {
                        result.ROI = System.Drawing.Rectangle.Empty;
                    }
                }
            }
        }
    }
    catch
    {
        // Do not leak the partially filled output when a tile fails.
        result.Dispose();
        throw;
    }

    _logger.Debug("ProcessTiled: Scale={Scale}, Output={W}x{H}", scale, result.Width, result.Height);
    return result;
}

/// <summary>
/// Number of distinct tile origins needed to cover <paramref name="length"/> pixels when tiles
/// advance by <paramref name="stride"/> and the last origin is clamped to length - tileSize.
/// </summary>
private static int CountTiles(int length, int tileSize, int stride)
{
    if (length <= 0)
    {
        return 0;
    }

    if (length <= tileSize)
    {
        return 1;
    }

    // ceil((length - tileSize) / stride) strides after the first tile.
    return (length - tileSize + stride - 1) / stride + 1;
}
/// <summary>
/// Looks for "{model}_x{scale}.onnx" and returns the first existing path, checking in order:
/// the application base directory's "Models" folder, the base directory itself, the current
/// directory's "Models" folder, and the current directory.
/// </summary>
/// <param name="model">Model name used as the file-name prefix.</param>
/// <param name="scale">Upscaling factor used in the file-name suffix.</param>
/// <returns>The path of the first existing candidate, or an empty string when none exists.</returns>
private static string FindModelFile(string model, int scale)
{
    string fileName = $"{model}_x{scale}.onnx";
    string appDir = AppDomain.CurrentDomain.BaseDirectory;

    // Candidate locations, highest priority first.
    string[] candidates =
    {
        Path.Combine(appDir, "Models", fileName),
        Path.Combine(appDir, fileName),
        Path.Combine(Directory.GetCurrentDirectory(), "Models", fileName),
        Path.Combine(Directory.GetCurrentDirectory(), fileName),
    };

    string? match = Array.Find(candidates, File.Exists);
    if (match is null)
    {
        _logger.Warning("Model file not found: {Model}_x{Scale}.onnx", model, scale);
        return string.Empty;
    }

    _logger.Debug("Found model file: {Path}", match);
    return match;
}
}