// 155 lines
// 5.4 KiB
// C#

using AmtScanner.Api.Models;
using AmtScanner.Api.Services;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.SignalR;
using System.Collections.Concurrent;
namespace AmtScanner.Api.Controllers;
/// <summary>
/// HTTP endpoints for starting, polling and cancelling AMT network scans.
/// Progress is pushed to all connected SignalR clients through <see cref="ScanProgressHub"/>
/// and is also kept in an in-memory table that the status endpoint reads.
/// </summary>
[ApiController]
[Route("api/[controller]")]
public class ScanController : ControllerBase
{
    private const string StatusRunning = "running";
    private const string StatusCompleted = "completed";
    private const string StatusCancelled = "cancelled";
    private const string StatusError = "error";

    private readonly IAmtScannerService _scannerService;
    private readonly IHubContext<ScanProgressHub> _hubContext;
    private readonly ILogger<ScanController> _logger;

    // Scan status per task id. Static because a controller instance only lives for one request,
    // while the scan it starts keeps running (and is polled) afterwards.
    // NOTE(review): entries are never removed, so this table grows by one entry per started scan
    // for the lifetime of the process; consider evicting finished tasks after a retention period.
    private static readonly ConcurrentDictionary<string, ScanStatusInfo> _scanStatuses = new();

    // Serializes Status transitions out of "running" so a task's terminal state is set exactly once.
    private static readonly object _statusGate = new();

    public ScanController(
        IAmtScannerService scannerService,
        IHubContext<ScanProgressHub> hubContext,
        ILogger<ScanController> logger)
    {
        _scannerService = scannerService;
        _hubContext = hubContext;
        _logger = logger;
    }

    /// <summary>
    /// Registers a new scan task and starts scanning the requested network in the background.
    /// </summary>
    /// <param name="request">Network segment and subnet mask forwarded to the scanner service.</param>
    /// <returns>
    /// The generated task id. The scan itself continues after this response has been sent;
    /// poll the status endpoint or listen on the hub for its outcome.
    /// </returns>
    [HttpPost("start")]
    public Task<ActionResult<ApiResponse<ScanStartResponse>>> StartScan([FromBody] ScanRequest request)
    {
        var taskId = Guid.NewGuid().ToString();
        _logger.LogInformation("Starting scan task: {TaskId}", taskId);

        // Initial state for the new task.
        var status = new ScanStatusInfo
        {
            TaskId = taskId,
            Status = StatusRunning,
            ScannedCount = 0,
            TotalCount = 0,
            FoundDevices = 0
        };
        _scanStatuses[taskId] = status;

        // A plain Action (not Progress<T>) is handed to the scanner, so each update is applied
        // synchronously on whatever thread the scanner invokes it from.
        Action<ScanProgress> progressCallback = p =>
        {
            status.ScannedCount = p.ScannedCount;
            status.TotalCount = p.TotalCount;
            status.FoundDevices = p.FoundDevices;
            status.CurrentIp = p.CurrentIp;

            // Debug level: this callback presumably fires for every probed address and would
            // otherwise flood Information-level logs.
            _logger.LogDebug("Progress update: scanned={Scanned}, total={Total}, found={Found}, ip={Ip}",
                p.ScannedCount, p.TotalCount, p.FoundDevices, p.CurrentIp);

            // Not awaited so the scan never waits on SignalR delivery; a failed push is logged
            // instead of becoming an unobserved task exception.
            _ = _hubContext.Clients.All.SendAsync("ReceiveScanProgress", p)
                .ContinueWith(
                    t => _logger.LogWarning(t.Exception, "Failed to push scan progress for task {TaskId}", taskId),
                    TaskContinuationOptions.OnlyOnFaulted);
        };

        // NOTE(review): this background task outlives the request. If IAmtScannerService is
        // registered as scoped (or depends on scoped services), it may be used after its scope has
        // been disposed — confirm its lifetime, or move the work into a hosted background service.
        _ = Task.Run(() => RunScanAsync(taskId, request, status, progressCallback));

        ActionResult<ApiResponse<ScanStartResponse>> result =
            Ok(ApiResponse<ScanStartResponse>.Success(new ScanStartResponse { TaskId = taskId }, "扫描任务已启动"));
        return Task.FromResult(result);
    }

    /// <summary>
    /// Returns the current status of a scan task.
    /// </summary>
    /// <param name="taskId">Id returned by the start endpoint.</param>
    /// <returns>The task's status, or a 404 failure payload if the id is unknown.</returns>
    [HttpGet("status/{taskId}")]
    public ActionResult<ApiResponse<ScanStatusInfo>> GetScanStatus(string taskId)
    {
        if (_scanStatuses.TryGetValue(taskId, out var status))
        {
            _logger.LogDebug("GetScanStatus: taskId={TaskId}, status={Status}, scanned={Scanned}, total={Total}, found={Found}",
                taskId, status.Status, status.ScannedCount, status.TotalCount, status.FoundDevices);
            return Ok(ApiResponse<ScanStatusInfo>.Success(status));
        }
        return Ok(ApiResponse<ScanStatusInfo>.Fail(404, "扫描任务不存在"));
    }

    /// <summary>
    /// Requests cancellation of a scan task.
    /// </summary>
    /// <param name="taskId">Id returned by the start endpoint.</param>
    /// <returns>A success payload; the scanner service is always asked to cancel the id.</returns>
    [HttpPost("cancel/{taskId}")]
    public ActionResult<ApiResponse<object>> CancelScan(string taskId)
    {
        // Mark the task cancelled *before* signalling the scanner: the background task may finish
        // or throw as soon as it observes the cancellation, and must not record "completed"/"error"
        // first. A task that has already finished keeps its final status.
        if (_scanStatuses.TryGetValue(taskId, out var status))
        {
            TryFinish(status, StatusCancelled);
        }

        _scannerService.CancelScan(taskId);

        return Ok(ApiResponse<object>.Success(null, "扫描任务已取消"));
    }

    /// <summary>
    /// Runs one scan to the end, records its final status and notifies clients.
    /// Exceptions from the scan are stored on <paramref name="status"/> and logged rather than
    /// propagated, because nothing observes the task that runs this method.
    /// </summary>
    private async Task RunScanAsync(
        string taskId,
        ScanRequest request,
        ScanStatusInfo status,
        Action<ScanProgress> progressCallback)
    {
        try
        {
            var foundDevicesList = await _scannerService.ScanNetworkAsync(
                taskId,
                request.NetworkSegment,
                request.SubnetMask,
                progressCallback
            );

            // The returned list is the authoritative device count, whatever progress last reported.
            status.FoundDevices = foundDevicesList.Count;
            // A task cancelled while the scan was finishing keeps its "cancelled" status.
            TryFinish(status, StatusCompleted);
            _logger.LogInformation("Scan task {TaskId} completed with {Count} devices found", taskId, foundDevicesList.Count);

            await _hubContext.Clients.All.SendAsync("ScanCompleted", new { taskId });
        }
        catch (Exception ex)
        {
            // Cancellation surfacing as an exception is recorded as "cancelled", not "error".
            var finalStatus = ex is OperationCanceledException ? StatusCancelled : StatusError;
            TryFinish(status, finalStatus, finalStatus == StatusError ? ex.Message : null);
            _logger.LogError(ex, "Error in scan task {TaskId}", taskId);

            try
            {
                await _hubContext.Clients.All.SendAsync("ScanError", new { taskId, error = ex.Message });
            }
            catch (Exception notifyEx)
            {
                _logger.LogWarning(notifyEx, "Failed to send ScanError for task {TaskId}", taskId);
            }
        }
    }

    /// <summary>
    /// Moves <paramref name="status"/> from "running" to <paramref name="finalStatus"/>,
    /// optionally recording an error message.
    /// </summary>
    /// <returns>
    /// <c>true</c> if the transition happened; <c>false</c> if the task had already left "running".
    /// </returns>
    private static bool TryFinish(ScanStatusInfo status, string finalStatus, string? error = null)
    {
        lock (_statusGate)
        {
            if (!string.Equals(status.Status, StatusRunning, StringComparison.Ordinal))
            {
                return false;
            }

            status.Status = finalStatus;
            if (error is not null)
            {
                status.Error = error;
            }
            return true;
        }
    }
}
/// <summary>
/// Payload returned by <c>ScanController.StartScan</c>: identifies the scan task that was started.
/// </summary>
public class ScanStartResponse
{
    /// <summary>Task identifier for the status/cancel endpoints; empty string until assigned.</summary>
    public string TaskId { get; set; } = "";
}
/// <summary>
/// Mutable progress record for one scan task, as returned by <c>ScanController.GetScanStatus</c>.
/// </summary>
public class ScanStatusInfo
{
    /// <summary>Id of the scan task this record describes; empty string by default.</summary>
    public string TaskId { get; set; } = "";

    /// <summary>
    /// Lifecycle state: "idle", "running", "completed", "cancelled" or "error". Starts as "idle".
    /// </summary>
    public string Status { get; set; } = "idle";

    /// <summary>Scanned count from the latest progress update.</summary>
    public int ScannedCount { get; set; }

    /// <summary>Total count from the latest progress update.</summary>
    public int TotalCount { get; set; }

    /// <summary>Number of devices found so far (or in total, once finished).</summary>
    public int FoundDevices { get; set; }

    /// <summary>Address currently being scanned, if any.</summary>
    public string? CurrentIp { get; set; }

    /// <summary>Error message when the scan failed; otherwise null.</summary>
    public string? Error { get; set; }
}
/// <summary>
/// Request body for starting a scan; both values are forwarded to the scanner service unchanged.
/// </summary>
public class ScanRequest
{
    /// <summary>Network segment to scan; empty string when not supplied.</summary>
    public string NetworkSegment { get; set; } = "";

    /// <summary>Subnet mask for the segment; empty string when not supplied.</summary>
    public string SubnetMask { get; set; } = "";
}