using AmtScanner.Api.Data;
using AmtScanner.Api.Models;
using Microsoft.EntityFrameworkCore;
using System.Collections.Concurrent;
using System.Globalization;
using System.Management;
using System.Net.NetworkInformation;
using System.Net.Sockets;

namespace AmtScanner.Api.Services;

// NOTE(review): the reviewed text of this file had lost its generic type arguments
// (e.g. "Task>", "IProgress progress", "GetRequiredService()"). They are restored here
// from how each value is used; the one argument that could not be derived from usage is
// called out on ResolveDbContext below.

/// <summary>
/// Discovers Windows/Linux hosts on a network segment, enriches Windows hosts through WMI,
/// and links discovered OS devices to AMT devices that report the same system UUID.
/// </summary>
public interface IWindowsScannerService
{
    /// <summary>
    /// Probes every host address of the given segment and returns the devices that were found.
    /// </summary>
    /// <param name="taskId">Identifier under which the scan can later be cancelled via <see cref="CancelScan"/>.</param>
    /// <param name="networkSegment">Dotted-quad IPv4 address inside the segment to scan.</param>
    /// <param name="subnetMask">Dotted-quad mask ("255.255.255.0") or prefix length ("/24" or "24").</param>
    /// <param name="progress">Receives one report per scanned address.</param>
    /// <param name="cancellationToken">Cancels the scan.</param>
    Task<List<OsDevice>> ScanNetworkAsync(
        string taskId,
        string networkSegment,
        string subnetMask,
        IProgress<OsScanProgress> progress,
        CancellationToken cancellationToken = default);

    /// <summary>Reads OS details of a remote Windows host through WMI; returns null when the query fails.</summary>
    Task<OsDevice?> GetOsInfoAsync(string ipAddress, string username, string password);

    /// <summary>Reads the system UUID of a remote Windows host through WMI; returns null when the query fails.</summary>
    Task<string?> GetSystemUuidAsync(string ipAddress, string username, string password);

    /// <summary>Links not-yet-bound OS devices to the AMT device that has the same system UUID.</summary>
    Task BindAmtDevicesAsync();

    /// <summary>Requests cancellation of the running scan registered under <paramref name="taskId"/>, if any.</summary>
    void CancelScan(string taskId);
}

/// <summary>
/// Default <see cref="IWindowsScannerService"/> implementation based on ICMP ping, TCP port probes and WMI.
/// </summary>
public class WindowsScannerService : IWindowsScannerService
{
    // Widest range CalculateIpRange will enumerate: 2^30 addresses is the most the original
    // int-based arithmetic could represent without overflowing.
    private const int MaxHostBits = 30;

    private readonly IServiceScopeFactory _scopeFactory;
    private readonly ILogger<WindowsScannerService> _logger;
    private readonly IConfiguration _configuration;

    // One cancellation source per running scan, keyed by task id.
    private readonly ConcurrentDictionary<string, CancellationTokenSource> _cancellationTokens = new();

    public WindowsScannerService(
        IServiceScopeFactory scopeFactory,
        ILogger<WindowsScannerService> logger,
        IConfiguration configuration)
    {
        _scopeFactory = scopeFactory;
        _logger = logger;
        _configuration = configuration;
    }

    /// <inheritdoc />
    public async Task<List<OsDevice>> ScanNetworkAsync(
        string taskId,
        string networkSegment,
        string subnetMask,
        IProgress<OsScanProgress> progress,
        CancellationToken cancellationToken = default)
    {
        _logger.LogInformation("Starting OS scan for task: {TaskId}", taskId);

        var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
        _cancellationTokens[taskId] = cts;

        try
        {
            var ipList = CalculateIpRange(networkSegment, subnetMask);
            var foundDevices = new ConcurrentBag<OsDevice>();
            int scannedCount = 0;
            int foundCount = 0;

            var threadPoolSize = _configuration.GetValue<int>("Scanner:ThreadPoolSize", 50);
            var parallelOptions = new ParallelOptions
            {
                MaxDegreeOfParallelism = threadPoolSize,
                CancellationToken = cts.Token
            };

            await Parallel.ForEachAsync(ipList, parallelOptions, async (ip, ct) =>
            {
                try
                {
                    var device = await ScanSingleHostAsync(ip, ct);
                    var scanned = Interlocked.Increment(ref scannedCount);

                    int found;
                    if (device != null)
                    {
                        foundDevices.Add(device);
                        found = Interlocked.Increment(ref foundCount);
                        await SaveOsDeviceAsync(device);
                    }
                    else
                    {
                        // Other workers update the counter concurrently; read it with a barrier.
                        found = Volatile.Read(ref foundCount);
                    }

                    progress.Report(new OsScanProgress
                    {
                        TaskId = taskId,
                        ScannedCount = scanned,
                        TotalCount = ipList.Count,
                        FoundDevices = found,
                        ProgressPercentage = (double)scanned / ipList.Count * 100,
                        CurrentIp = ip,
                        LatestDevice = device
                    });
                }
                catch (Exception ex)
                {
                    // A failure on one address must not abort the remaining addresses.
                    _logger.LogDebug(ex, "Error scanning {Ip}", ip);
                }
            });

            // Once the scan is complete, try to bind the discovered hosts to AMT devices.
            await BindAmtDevicesAsync();

            return foundDevices.ToList();
        }
        finally
        {
            // Remove only this scan's own entry: a newer scan that reused the task id may have
            // replaced it, and that scan must stay cancellable.
            _cancellationTokens.TryRemove(new KeyValuePair<string, CancellationTokenSource>(taskId, cts));
            cts.Dispose();
        }
    }

    /// <inheritdoc />
    public void CancelScan(string taskId)
    {
        if (!_cancellationTokens.TryGetValue(taskId, out var cts))
        {
            return;
        }

        try
        {
            cts.Cancel();
            _logger.LogInformation("OS scan task {TaskId} cancelled", taskId);
        }
        catch (ObjectDisposedException)
        {
            // The scan finished and disposed its token source between the lookup and Cancel();
            // there is nothing left to cancel.
        }
    }

    /// <summary>
    /// Probes one address: ping first, then the Windows ports, then SSH.
    /// Returns null when the host does not answer or matches neither profile.
    /// </summary>
    private async Task<OsDevice?> ScanSingleHostAsync(string ip, CancellationToken ct)
    {
        // Ping first to see whether the host is online at all.
        if (!await IsHostOnlineAsync(ip, ct))
        {
            return null;
        }

        // Probe the ports that are typical for Windows.
        if (await IsWindowsHostAsync(ip, ct))
        {
            return CreateDiscoveredDevice(ip, OsType.Windows, "通过端口扫描发现");
        }

        // An open SSH port is treated as a Linux host.
        if (await IsPortOpenAsync(ip, 22, 2000, ct))
        {
            return CreateDiscoveredDevice(ip, OsType.Linux, "通过 SSH 端口发现");
        }

        return null;
    }

    /// <summary>Builds the record for a host that has just been discovered online.</summary>
    private static OsDevice CreateDiscoveredDevice(string ip, OsType osType, string description)
    {
        // Single timestamp so the three time fields of one discovery agree exactly.
        var now = DateTime.UtcNow;
        return new OsDevice
        {
            IpAddress = ip,
            OsType = osType,
            IsOnline = true,
            LastOnlineAt = now,
            DiscoveredAt = now,
            LastUpdatedAt = now,
            Description = description
        };
    }

    /// <summary>Returns true when the host answers an ICMP echo within one second.</summary>
    private async Task<bool> IsHostOnlineAsync(string ip, CancellationToken ct)
    {
        if (ct.IsCancellationRequested)
        {
            return false;
        }

        try
        {
            using var ping = new Ping();
            var reply = await ping.SendPingAsync(ip, 1000);
            return reply.Status == IPStatus.Success;
        }
        catch
        {
            // Best-effort probe: any failure to ping is reported as "offline".
            return false;
        }
    }

    /// <summary>Returns true as soon as one of the Windows-typical TCP ports accepts a connection.</summary>
    private async Task<bool> IsWindowsHostAsync(string ip, CancellationToken ct)
    {
        // Ports commonly open on Windows: 135 (RPC), 445 (SMB), 3389 (RDP), 5985 (WinRM).
        var windowsPorts = new[] { 135, 445, 3389, 5985 };

        foreach (var port in windowsPorts)
        {
            if (await IsPortOpenAsync(ip, port, 1000, ct))
            {
                return true;
            }
        }

        return false;
    }

    /// <summary>Returns true when a TCP connection to the port succeeds within <paramref name="timeoutMs"/>.</summary>
    private async Task<bool> IsPortOpenAsync(string ip, int port, int timeoutMs, CancellationToken ct)
    {
        try
        {
            using var client = new TcpClient();
            using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
            cts.CancelAfter(timeoutMs);

            await client.ConnectAsync(ip, port, cts.Token);
            return true;
        }
        catch
        {
            // Refused, unreachable, timed out or cancelled: all count as "not open".
            return false;
        }
    }

    /// <summary>
    /// Reads system information of a remote Windows host through WMI.
    /// </summary>
    /// <returns>The populated device, or null when the WMI connection or a query fails.</returns>
    public async Task<OsDevice?> GetOsInfoAsync(string ipAddress, string username, string password)
    {
        // The System.Management calls used here block, so they run on a pool thread.
        return await Task.Run(() =>
        {
            try
            {
                var scope = ConnectWmiScope(ipAddress, username, password);

                var device = new OsDevice
                {
                    IpAddress = ipAddress,
                    OsType = OsType.Windows,
                    IsOnline = true,
                    LastOnlineAt = DateTime.UtcNow,
                    LastUpdatedAt = DateTime.UtcNow
                };

                // System UUID.
                VisitWmiResults(scope, "SELECT UUID FROM Win32_ComputerSystemProduct", obj =>
                {
                    device.SystemUuid = obj["UUID"]?.ToString();
                    return true;
                });

                // Operating system details.
                VisitWmiResults(scope, "SELECT Caption, Version, OSArchitecture, LastBootUpTime FROM Win32_OperatingSystem", obj =>
                {
                    device.OsVersion = $"{obj["Caption"]} ({obj["Version"]})";
                    device.Architecture = obj["OSArchitecture"]?.ToString();

                    var lastBootStr = obj["LastBootUpTime"]?.ToString();
                    if (!string.IsNullOrEmpty(lastBootStr))
                    {
                        // NOTE(review): the other timestamps on the device are UTC; confirm which
                        // time zone consumers expect for the converted boot time.
                        device.LastBootTime = ManagementDateTimeConverter.ToDateTime(lastBootStr);
                    }

                    return true;
                });

                // Computer name and logged-in user.
                VisitWmiResults(scope, "SELECT Name, UserName FROM Win32_ComputerSystem", obj =>
                {
                    device.Hostname = obj["Name"]?.ToString();
                    device.LoggedInUser = obj["UserName"]?.ToString();
                    return true;
                });

                // MAC address: first IP-enabled adapter that reports one.
                VisitWmiResults(scope, "SELECT MACAddress FROM Win32_NetworkAdapterConfiguration WHERE IPEnabled = True", obj =>
                {
                    var mac = obj["MACAddress"]?.ToString();
                    if (string.IsNullOrEmpty(mac))
                    {
                        return false;
                    }

                    device.MacAddress = mac;
                    return true;
                });

                device.Description = "通过 WMI 获取详细信息";
                return device;
            }
            catch (Exception ex)
            {
                _logger.LogWarning(ex, "Failed to get OS info for {Ip} via WMI", ipAddress);
                return null;
            }
        });
    }

    /// <summary>
    /// Reads the system UUID of a remote Windows host through WMI.
    /// </summary>
    /// <returns>The UUID, or null when none is reported or the WMI connection or query fails.</returns>
    public async Task<string?> GetSystemUuidAsync(string ipAddress, string username, string password)
    {
        return await Task.Run(() =>
        {
            try
            {
                var scope = ConnectWmiScope(ipAddress, username, password);

                string? uuid = null;
                VisitWmiResults(scope, "SELECT UUID FROM Win32_ComputerSystemProduct", obj =>
                {
                    uuid = obj["UUID"]?.ToString();
                    return true;
                });

                return uuid;
            }
            catch (Exception ex)
            {
                _logger.LogWarning(ex, "Failed to get UUID for {Ip}", ipAddress);
                return null;
            }
        });
    }

    /// <summary>Opens a connection to the root\cimv2 WMI namespace of the remote host.</summary>
    private static ManagementScope ConnectWmiScope(string ipAddress, string username, string password)
    {
        var options = new ConnectionOptions
        {
            Username = username,
            Password = password,
            Impersonation = ImpersonationLevel.Impersonate,
            Authentication = AuthenticationLevel.PacketPrivacy
        };

        var scope = new ManagementScope($"\\\\{ipAddress}\\root\\cimv2", options);
        scope.Connect();
        return scope;
    }

    /// <summary>
    /// Runs a WQL query and hands each result to <paramref name="visitor"/> until it returns true.
    /// The searcher, the result collection and every visited object are disposed.
    /// </summary>
    private static void VisitWmiResults(ManagementScope scope, string wql, Func<ManagementBaseObject, bool> visitor)
    {
        using var searcher = new ManagementObjectSearcher(scope, new ObjectQuery(wql));
        using var results = searcher.Get();

        foreach (ManagementBaseObject item in results)
        {
            using (item)
            {
                if (visitor(item))
                {
                    return;
                }
            }
        }
    }

    /// <summary>
    /// Binds every OS device that has a UUID but no AMT link to the AMT device with the same UUID.
    /// </summary>
    public async Task BindAmtDevicesAsync()
    {
        using var scope = _scopeFactory.CreateScope();
        var context = ResolveDbContext(scope);

        // All OS devices that have a UUID and are not bound yet.
        var osDevices = await context.OsDevices
            .Where(o => o.SystemUuid != null && o.AmtDeviceId == null)
            .ToListAsync();

        // All AMT devices that have a UUID.
        var amtDevices = await context.AmtDevices
            .Where(a => a.SystemUuid != null)
            .ToListAsync();

        // Group before building the map: two AMT rows with the same UUID would make a plain
        // ToDictionary throw and fail the whole bind. The comparison ignores case because the
        // two sources may format the hex digits of a UUID differently.
        var amtUuidMap = amtDevices
            .GroupBy(a => a.SystemUuid!, StringComparer.OrdinalIgnoreCase)
            .ToDictionary(g => g.Key, g => g.First(), StringComparer.OrdinalIgnoreCase);

        if (amtUuidMap.Count != amtDevices.Count)
        {
            _logger.LogWarning(
                "{DuplicateCount} AMT device(s) share a system UUID with another AMT device; the first one is used for binding",
                amtDevices.Count - amtUuidMap.Count);
        }

        foreach (var osDevice in osDevices)
        {
            if (osDevice.SystemUuid != null && amtUuidMap.TryGetValue(osDevice.SystemUuid, out var amtDevice))
            {
                osDevice.AmtDeviceId = amtDevice.Id;
                _logger.LogInformation("Bound OS device {OsIp} to AMT device {AmtIp} via UUID {Uuid}",
                    osDevice.IpAddress, amtDevice.IpAddress, osDevice.SystemUuid);
            }
        }

        await context.SaveChangesAsync();
    }

    /// <summary>
    /// Inserts the device, or refreshes the stored row that has the same IP address.
    /// UUID, hostname and OS version are only overwritten when the new value is non-empty.
    /// </summary>
    private async Task SaveOsDeviceAsync(OsDevice device)
    {
        using var scope = _scopeFactory.CreateScope();
        var context = ResolveDbContext(scope);

        var existing = await context.OsDevices
            .FirstOrDefaultAsync(d => d.IpAddress == device.IpAddress);

        if (existing != null)
        {
            existing.OsType = device.OsType;
            existing.IsOnline = device.IsOnline;
            existing.LastOnlineAt = device.LastOnlineAt;
            existing.LastUpdatedAt = DateTime.UtcNow;

            if (!string.IsNullOrEmpty(device.SystemUuid))
            {
                existing.SystemUuid = device.SystemUuid;
            }

            if (!string.IsNullOrEmpty(device.Hostname))
            {
                existing.Hostname = device.Hostname;
            }

            if (!string.IsNullOrEmpty(device.OsVersion))
            {
                existing.OsVersion = device.OsVersion;
            }
        }
        else
        {
            context.OsDevices.Add(device);
        }

        await context.SaveChangesAsync();
    }

    // NOTE(review): the DbContext type name was not present in the reviewed text (the generic
    // argument of GetRequiredService was missing). AppDbContext, assumed to live in
    // AmtScanner.Api.Data, is a placeholder guess — confirm the real type name before merging.
    private static AppDbContext ResolveDbContext(IServiceScope scope) =>
        scope.ServiceProvider.GetRequiredService<AppDbContext>();

    /// <summary>
    /// Lists the host addresses of the subnet that contains <paramref name="networkSegment"/>.
    /// The network and broadcast addresses are skipped, except for /31 and /32 where every
    /// address is a host. Invalid input is logged and yields an empty list.
    /// </summary>
    private List<string> CalculateIpRange(string networkSegment, string subnetMask)
    {
        var ipList = new List<string>();

        try
        {
            var cidr = SubnetMaskToCidr(subnetMask);
            var hostBits = 32 - cidr;
            if (hostBits > MaxHostBits)
            {
                throw new ArgumentOutOfRangeException(
                    nameof(subnetMask), subnetMask, "The subnet is too large to enumerate.");
            }

            long blockSize = 1L << hostBits;

            // Clear the host bits so that an address from the middle of the subnet
            // (e.g. 192.168.1.100/24) still enumerates that subnet and nothing beyond it.
            long networkLong = IpToLong(networkSegment) & ~(blockSize - 1) & 0xFFFFFFFFL;

            bool hasNetworkAndBroadcast = hostBits >= 2;
            long first = hasNetworkAndBroadcast ? 1 : 0;
            long last = hasNetworkAndBroadcast ? blockSize - 2 : blockSize - 1;

            for (long offset = first; offset <= last; offset++)
            {
                ipList.Add(LongToIp(networkLong + offset));
            }
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error calculating IP range");
        }

        return ipList;
    }

    /// <summary>Converts a dotted-quad IPv4 address to its 32-bit value.</summary>
    /// <exception cref="FormatException">The text is not four octets in the range 0-255.</exception>
    private static long IpToLong(string ipAddress)
    {
        var parts = ipAddress.Split('.');
        if (parts.Length != 4)
        {
            throw new FormatException($"'{ipAddress}' is not a dotted-quad IPv4 address.");
        }

        const NumberStyles octetStyle = NumberStyles.AllowLeadingWhite | NumberStyles.AllowTrailingWhite;

        long result = 0;
        foreach (var part in parts)
        {
            if (!byte.TryParse(part, octetStyle, CultureInfo.InvariantCulture, out var octet))
            {
                throw new FormatException($"'{ipAddress}' is not a dotted-quad IPv4 address.");
            }

            result = (result << 8) | octet;
        }

        return result;
    }

    /// <summary>Formats a 32-bit IPv4 value as dotted-quad text.</summary>
    private static string LongToIp(long ip) =>
        $"{(ip >> 24) & 0xFF}.{(ip >> 16) & 0xFF}.{(ip >> 8) & 0xFF}.{ip & 0xFF}";

    /// <summary>
    /// Returns the prefix length for a mask given as "/24", "24" or "255.255.255.0".
    /// </summary>
    /// <exception cref="FormatException">The mask is malformed, out of range or not contiguous.</exception>
    private static int SubnetMaskToCidr(string subnetMask)
    {
        var text = subnetMask.Trim();
        if (text.StartsWith('/'))
        {
            text = text[1..].Trim();
        }

        if (!text.Contains('.'))
        {
            // A bare number is a prefix length. Counting its bits as if it were an octet
            // would turn "24" into a prefix length of 2.
            var prefix = int.Parse(text, NumberStyles.Integer, CultureInfo.InvariantCulture);
            if (prefix < 0 || prefix > 32)
            {
                throw new FormatException($"Prefix length '{subnetMask}' is outside 0-32.");
            }

            return prefix;
        }

        long mask = IpToLong(text);

        // A valid mask is a run of ones followed by a run of zeros, so its complement
        // is of the form 2^k - 1 and shares no bit with its successor.
        long inverted = ~mask & 0xFFFFFFFFL;
        if ((inverted & (inverted + 1)) != 0)
        {
            throw new FormatException($"Subnet mask '{subnetMask}' is not contiguous.");
        }

        int cidr = 0;
        for (long bits = mask; bits != 0; bits >>= 1)
        {
            cidr += (int)(bits & 1);
        }

        return cidr;
    }
}

/// <summary>
/// Progress snapshot reported once per scanned address.
/// </summary>
public class OsScanProgress
{
    /// <summary>Identifier of the scan that produced this report.</summary>
    public string TaskId { get; set; } = string.Empty;

    /// <summary>Number of addresses scanned so far.</summary>
    public int ScannedCount { get; set; }

    /// <summary>Total number of addresses in the scan.</summary>
    public int TotalCount { get; set; }

    /// <summary>Number of devices found so far.</summary>
    public int FoundDevices { get; set; }

    /// <summary>ScannedCount as a percentage of TotalCount.</summary>
    public double ProgressPercentage { get; set; }

    /// <summary>Address whose scan triggered this report.</summary>
    public string? CurrentIp { get; set; }

    /// <summary>Device found at <see cref="CurrentIp"/>, or null when none was found there.</summary>
    public OsDevice? LatestDevice { get; set; }
}