484 lines
17 KiB
C#

using AmtScanner.Api.Data;
using AmtScanner.Api.Models;
using Intel.Management.Wsman;
using Microsoft.EntityFrameworkCore;
using System.Collections.Concurrent;
using System.Net;
using System.Net.Sockets;
using System.Net.Security;
using System.Security;
using System.Security.Cryptography.X509Certificates;
namespace AmtScanner.Api.Services;
/// <summary>
/// Scans an IPv4 subnet for Intel AMT devices: probes the AMT-related TCP ports on every host,
/// queries responsive hosts over WS-Management with the default stored credential, and persists
/// each discovered device.
/// </summary>
public class AmtScannerService : IAmtScannerService
{
    // Ports probed on every host. 16992 is queried over HTTP and 16993 over HTTPS (see GetAmtInfoUsingSDK);
    // 623 is only used for detection (presumably ASF/RMCP — confirm).
    private static readonly int[] AmtPorts = { 16992, 16993, 623 };

    private readonly IServiceScopeFactory _scopeFactory;
    private readonly ILogger<AmtScannerService> _logger;
    private readonly IConfiguration _configuration;

    // Cancellation source of each running scan, keyed by task id, so CancelScan can reach it.
    private readonly ConcurrentDictionary<string, CancellationTokenSource> _cancellationTokens = new();

    public AmtScannerService(
        IServiceScopeFactory scopeFactory,
        ILogger<AmtScannerService> logger,
        IConfiguration configuration)
    {
        _scopeFactory = scopeFactory;
        _logger = logger;
        _configuration = configuration;
    }

    /// <summary>
    /// Scans every host address of the given subnet in parallel, saving and reporting each AMT device found.
    /// </summary>
    /// <param name="taskId">Identifier used to cancel this scan via <see cref="CancelScan"/>.</param>
    /// <param name="networkSegment">Any IPv4 address inside the subnet (host bits are masked off).</param>
    /// <param name="subnetMask">Dotted mask ("255.255.255.0"), "/24", or a bare prefix length ("24").</param>
    /// <param name="progress">Receives one report per scanned address.</param>
    /// <param name="cancellationToken">External cancellation, linked with <see cref="CancelScan"/>.</param>
    /// <returns>All devices found before the scan completed.</returns>
    /// <exception cref="OperationCanceledException">The scan was cancelled.</exception>
    public async Task<List<AmtDevice>> ScanNetworkAsync(
        string taskId,
        string networkSegment,
        string subnetMask,
        IProgress<ScanProgress> progress,
        CancellationToken cancellationToken = default)
    {
        _logger.LogInformation("Starting network scan for task: {TaskId}", taskId);

        var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
        _cancellationTokens[taskId] = cts;

        try
        {
            var ipList = CalculateIpRange(networkSegment, subnetMask);
            _logger.LogInformation("Calculated {Count} IP addresses to scan", ipList.Count);

            var foundDevices = new ConcurrentBag<AmtDevice>();
            int scannedCount = 0;
            int foundCount = 0;

            var threadPoolSize = _configuration.GetValue<int>("Scanner:ThreadPoolSize", 100);
            if (threadPoolSize is 0 or < -1)
            {
                // ParallelOptions throws for 0 and for values below -1.
                _logger.LogWarning("Invalid Scanner:ThreadPoolSize {Value}; using 100", threadPoolSize);
                threadPoolSize = 100;
            }

            var parallelOptions = new ParallelOptions
            {
                MaxDegreeOfParallelism = threadPoolSize,
                CancellationToken = cts.Token
            };

            await Parallel.ForEachAsync(ipList, parallelOptions, async (ip, ct) =>
            {
                try
                {
                    var device = await ScanSingleDeviceAsync(ip, ct);
                    var scanned = Interlocked.Increment(ref scannedCount);

                    if (device != null)
                    {
                        foundDevices.Add(device);
                        var found = Interlocked.Increment(ref foundCount);

                        // The single place a discovered device is persisted.
                        await SaveDeviceAsync(device);

                        progress.Report(new ScanProgress
                        {
                            TaskId = taskId,
                            ScannedCount = scanned,
                            TotalCount = ipList.Count,
                            FoundDevices = found,
                            ProgressPercentage = (double)scanned / ipList.Count * 100,
                            CurrentIp = ip,
                            LatestDevice = device
                        });
                    }
                    else
                    {
                        progress.Report(new ScanProgress
                        {
                            TaskId = taskId,
                            ScannedCount = scanned,
                            TotalCount = ipList.Count,
                            FoundDevices = Volatile.Read(ref foundCount),
                            ProgressPercentage = (double)scanned / ipList.Count * 100,
                            CurrentIp = ip
                        });
                    }
                }
                catch (OperationCanceledException) when (ct.IsCancellationRequested)
                {
                    // Cancellation is not a per-host error; Parallel.ForEachAsync observes the cancelled token itself.
                }
                catch (Exception ex)
                {
                    _logger.LogError(ex, "Error scanning {Ip}", ip);
                }
            });

            _logger.LogInformation("Scan completed for task: {TaskId}. Found {Count} devices", taskId, foundDevices.Count);
            return foundDevices.ToList();
        }
        finally
        {
            // Remove only our own entry: a newer scan registered under the same task id must stay cancellable.
            _cancellationTokens.TryRemove(new KeyValuePair<string, CancellationTokenSource>(taskId, cts));
            cts.Dispose();
        }
    }

    /// <summary>
    /// Requests cancellation of the running scan registered under <paramref name="taskId"/>; no-op if none.
    /// </summary>
    public void CancelScan(string taskId)
    {
        if (!_cancellationTokens.TryGetValue(taskId, out var cts))
        {
            return;
        }

        try
        {
            cts.Cancel();
            _logger.LogInformation("Scan task {TaskId} cancelled", taskId);
        }
        catch (ObjectDisposedException)
        {
            // The scan finished and disposed its source between the lookup and Cancel(); nothing to cancel.
        }
    }

    /// <summary>
    /// Probes one address. Returns null when no AMT port is open; otherwise a device built from a
    /// WS-Management query, or a basic port-scan-only device when the query is impossible or fails.
    /// Does not persist the device — the caller does.
    /// </summary>
    private async Task<AmtDevice?> ScanSingleDeviceAsync(string ip, CancellationToken cancellationToken)
    {
        var openPorts = await GetOpenAmtPortsAsync(ip, cancellationToken);
        if (openPorts.Count == 0)
        {
            return null;
        }

        _logger.LogInformation("Found AMT device at {Ip} with ports: {Ports}", ip, string.Join(", ", openPorts));

        // Resolve the credential in its own short-lived scope so parallel scans do not share scoped
        // services (presumably the DbContext behind ICredentialService).
        AmtCredential? credential;
        string? decryptedPassword = null;
        using (var credScope = _scopeFactory.CreateScope())
        {
            var credentialService = credScope.ServiceProvider.GetRequiredService<ICredentialService>();
            credential = await credentialService.GetDefaultCredentialAsync();
            if (credential != null)
            {
                decryptedPassword = credentialService.DecryptPassword(credential.Password);
            }
        }

        if (credential != null && decryptedPassword != null)
        {
            try
            {
                var device = await GetAmtInfoUsingSDK(ip, credential.Username, decryptedPassword, openPorts);
                if (device != null)
                {
                    return device;
                }
            }
            catch (Exception ex)
            {
                _logger.LogWarning(ex, "Failed to get AMT info for {Ip} using SDK", ip);
            }
        }

        // Fallback: the ports answered but no details could be queried.
        return new AmtDevice
        {
            IpAddress = ip,
            MajorVersion = 0,
            MinorVersion = 0,
            ProvisioningState = ProvisioningState.UNKNOWN,
            Description = openPorts.Contains(16993) && !openPorts.Contains(16992)
                ? $"AMT device detected (HTTPS only). Open ports: {string.Join(", ", openPorts)}. Unable to query details via HTTPS."
                : $"Detected via port scan. Open ports: {string.Join(", ", openPorts)}",
            AmtOnline = true,
            OsOnline = false,
            DiscoveredAt = DateTime.UtcNow,
            LastSeenAt = DateTime.UtcNow
        };
    }

    /// <summary>
    /// Queries version, hostname and provisioning state over WS-Management using the Intel SDK.
    /// Prefers HTTP (16992) over HTTPS (16993). Returns null when neither port is open, when no
    /// version could be read, or on any error (errors are logged, never thrown).
    /// </summary>
    private async Task<AmtDevice?> GetAmtInfoUsingSDK(string ip, string username, string password, List<int> openPorts)
    {
        try
        {
            string protocol;
            int port;
            if (openPorts.Contains(16992))
            {
                // Prefer HTTP if available (faster, no cert issues)
                protocol = "http";
                port = 16992;
            }
            else if (openPorts.Contains(16993))
            {
                protocol = "https";
                port = 16993;
            }
            else
            {
                _logger.LogWarning("No suitable AMT port found for {Ip}. Open ports: {Ports}", ip, string.Join(", ", openPorts));
                return null;
            }

            _logger.LogInformation("Connecting to AMT device at {Ip} using {Protocol}://{Ip}:{Port} with username: {Username}",
                ip, protocol, ip, port, username);

            // Disposed when this method returns, after all queries below have completed.
            using var securePassword = new SecureString();
            foreach (char c in password)
            {
                securePassword.AppendChar(c);
            }
            securePassword.MakeReadOnly();

            var connection = new WsmanConnection();
            connection.Address = $"{protocol}://{ip}:{port}/wsman";
            connection.SetCredentials(username, securePassword);

            if (protocol == "https")
            {
                // NOTE(review): any self-signed certificate is accepted, so the HTTPS channel is not
                // authenticated and credentials could be sent to an impostor. Consider certificate pinning.
                connection.Options.ServerCertificateValidationCallback = (certificate, sslPolicyErrors) =>
                {
                    if (certificate.Subject.Equals(certificate.Issuer))
                    {
                        return true;
                    }
                    if (sslPolicyErrors == System.Net.Security.SslPolicyErrors.None)
                    {
                        return true;
                    }
                    return false;
                };
            }

            string? version = null;
            string? hostname = null;
            // -1 maps to UNKNOWN; a failed or unparsable query must not be reported as PRE (0).
            int provisioningState = -1;

            // The SDK calls are synchronous; run them off the calling thread.
            await Task.Run(() =>
            {
                try
                {
                    var versionQuery = connection.ExecQuery("SELECT * FROM CIM_SoftwareIdentity WHERE InstanceID='AMT'");
                    foreach (Intel.Management.Wsman.IWsmanItem item in versionQuery)
                    {
                        var versionProp = item.Object.GetProperty("VersionString");
                        if (!versionProp.IsNull)
                        {
                            version = versionProp.ToString();
                            break;
                        }
                    }

                    // Only the first settings instance is inspected.
                    var settingsQuery = connection.ExecQuery("SELECT * FROM AMT_GeneralSettings");
                    foreach (Intel.Management.Wsman.IWsmanItem item in settingsQuery)
                    {
                        var hostnameProp = item.Object.GetProperty("HostName");
                        if (!hostnameProp.IsNull)
                        {
                            hostname = hostnameProp.ToString();
                        }
                        break;
                    }

                    var setupQuery = connection.ExecQuery("SELECT * FROM AMT_SetupAndConfigurationService");
                    foreach (Intel.Management.Wsman.IWsmanItem item in setupQuery)
                    {
                        var stateProp = item.Object.GetProperty("ProvisioningState");
                        if (!stateProp.IsNull && int.TryParse(stateProp.ToString(), out var parsedState))
                        {
                            provisioningState = parsedState;
                        }
                        break;
                    }
                }
                catch (Exception ex)
                {
                    _logger.LogWarning(ex, "Error querying AMT info for {Ip}", ip);
                }
            });

            if (version == null)
            {
                return null;
            }

            _logger.LogInformation("Successfully got AMT info for {Ip}: version={Version}, state={State}",
                ip, version, provisioningState);

            return new AmtDevice
            {
                IpAddress = ip,
                Hostname = hostname ?? ip,
                MajorVersion = ParseMajorVersion(version),
                MinorVersion = ParseMinorVersion(version),
                ProvisioningState = MapProvisioningStateFromInt(provisioningState),
                Description = $"Detected via WS-Management ({protocol.ToUpper()}). Version: {version}",
                AmtOnline = true,
                OsOnline = false,
                DiscoveredAt = DateTime.UtcNow,
                LastSeenAt = DateTime.UtcNow
            };
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error getting AMT info for {Ip}", ip);
            return null;
        }
    }

    /// <summary>
    /// Returns the subset of <see cref="AmtPorts"/> accepting TCP connections, in <see cref="AmtPorts"/> order.
    /// </summary>
    private async Task<List<int>> GetOpenAmtPortsAsync(string ip, CancellationToken cancellationToken)
    {
        var timeout = _configuration.GetValue<int>("Scanner:TimeoutSeconds", 3) * 1000;

        // Probe all ports concurrently so a silent host costs one timeout instead of one per port.
        var probes = AmtPorts.Select(port => IsPortOpenAsync(ip, port, timeout, cancellationToken)).ToArray();
        var isOpen = await Task.WhenAll(probes);

        return AmtPorts.Where((_, index) => isOpen[index]).ToList();
    }

    /// <summary>
    /// True when a TCP connection to ip:port succeeds within <paramref name="timeoutMs"/>.
    /// Best effort: any failure, timeout or cancellation yields false.
    /// </summary>
    private async Task<bool> IsPortOpenAsync(string ip, int port, int timeoutMs, CancellationToken cancellationToken)
    {
        try
        {
            using var client = new TcpClient();
            using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
            cts.CancelAfter(timeoutMs);
            await client.ConnectAsync(ip, port, cts.Token);
            return true;
        }
        catch
        {
            return false;
        }
    }

    /// <summary>
    /// Enumerates the host addresses of the subnet containing <paramref name="networkSegment"/>.
    /// Network and broadcast addresses are excluded, except for /31 and /32, where every address is returned.
    /// Invalid input or subnets larger than /2 are logged and yield an empty list.
    /// </summary>
    private List<string> CalculateIpRange(string networkSegment, string subnetMask)
    {
        var ipList = new List<string>();
        try
        {
            var cidr = SubnetMaskToCidr(subnetMask);
            if (cidr is < 0 or > 32)
            {
                _logger.LogError("Invalid prefix length {Cidr} derived from subnet mask {SubnetMask}", cidr, subnetMask);
                return ipList;
            }

            var totalHosts = 1L << (32 - cidr);
            if (totalHosts > int.MaxValue)
            {
                _logger.LogError("Subnet /{Cidr} is too large to scan", cidr);
                return ipList;
            }

            // Clear host bits so an address inside the subnet (e.g. 192.168.1.50/24) still maps to its own subnet.
            var networkLong = IpToLong(networkSegment) & PrefixToMask(cidr);

            var (first, last) = totalHosts <= 2 ? (0L, totalHosts - 1) : (1L, totalHosts - 2);
            for (var i = first; i <= last; i++)
            {
                ipList.Add(LongToIp(networkLong + i));
            }
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error calculating IP range");
        }
        return ipList;
    }

    /// <summary>
    /// Parses a dotted-quad IPv4 address into its 32-bit value.
    /// </summary>
    /// <exception cref="FormatException">Not four dot-separated numeric parts.</exception>
    /// <exception cref="OverflowException">An octet is outside 0-255.</exception>
    private long IpToLong(string ipAddress)
    {
        var octets = ipAddress.Trim().Split('.');
        if (octets.Length != 4)
        {
            throw new FormatException($"'{ipAddress}' is not a dotted-quad IPv4 address.");
        }

        long result = 0;
        foreach (var octet in octets)
        {
            // byte.Parse rejects values above 255 instead of letting them bleed into the next octet.
            result = (result << 8) | byte.Parse(octet);
        }
        return result;
    }

    private string LongToIp(long ip)
    {
        return $"{(ip >> 24) & 0xFF}.{(ip >> 16) & 0xFF}.{(ip >> 8) & 0xFF}.{ip & 0xFF}";
    }

    // Network mask (in the low 32 bits) for a prefix length of 0-32; a long shift by 32 yields 0 for /0.
    private static long PrefixToMask(int cidr) => (0xFFFFFFFFL << (32 - cidr)) & 0xFFFFFFFFL;

    /// <summary>
    /// Converts "/24", "24" or "255.255.255.0" into a prefix length. The range is validated by the caller.
    /// </summary>
    /// <exception cref="FormatException">Unparsable input or a non-contiguous dotted mask.</exception>
    private int SubnetMaskToCidr(string subnetMask)
    {
        var mask = subnetMask.Trim();

        if (mask.StartsWith('/'))
        {
            return int.Parse(mask.Substring(1));
        }

        // A dotted mask always contains '.', so anything else is a bare prefix length.
        if (!mask.Contains('.'))
        {
            return int.Parse(mask);
        }

        var maskLong = IpToLong(mask);
        var cidr = 0;
        for (var bits = maskLong; bits != 0; bits &= bits - 1)
        {
            cidr++;
        }

        if (maskLong != PrefixToMask(cidr))
        {
            throw new FormatException($"Subnet mask '{subnetMask}' is not a contiguous network mask.");
        }
        return cidr;
    }

    /// <summary>
    /// Inserts the device, or refreshes the stored row with the same IP address, in a fresh DI scope.
    /// </summary>
    private async Task SaveDeviceAsync(AmtDevice device)
    {
        using var scope = _scopeFactory.CreateScope();
        var context = scope.ServiceProvider.GetRequiredService<AppDbContext>();

        var existing = await context.AmtDevices
            .FirstOrDefaultAsync(d => d.IpAddress == device.IpAddress);

        if (existing != null)
        {
            existing.Hostname = device.Hostname;
            existing.MajorVersion = device.MajorVersion;
            existing.MinorVersion = device.MinorVersion;
            existing.ProvisioningState = device.ProvisioningState;
            existing.Description = device.Description;
            existing.AmtOnline = device.AmtOnline;
            existing.LastSeenAt = DateTime.UtcNow;
            // Explicit Update keeps this working even if the context does not track queried entities.
            context.AmtDevices.Update(existing);
        }
        else
        {
            context.AmtDevices.Add(device);
        }

        await context.SaveChangesAsync();
    }

    // Leading numeric component of a dotted version string, or 0.
    private int ParseMajorVersion(string version)
    {
        var parts = version.Split('.');
        return parts.Length > 0 && int.TryParse(parts[0], out var major) ? major : 0;
    }

    // Second numeric component of a dotted version string, or 0.
    private int ParseMinorVersion(string version)
    {
        var parts = version.Split('.');
        return parts.Length > 1 && int.TryParse(parts[1], out var minor) ? minor : 0;
    }

    // Maps a textual state name; not called anywhere in this class.
    private ProvisioningState MapProvisioningState(string state)
    {
        return state?.ToUpper() switch
        {
            "PRE" => ProvisioningState.PRE,
            "IN" => ProvisioningState.IN,
            "POST" => ProvisioningState.POST,
            _ => ProvisioningState.UNKNOWN
        };
    }

    // Maps AMT_SetupAndConfigurationService.ProvisioningState (0/1/2); anything else, including -1, is UNKNOWN.
    private ProvisioningState MapProvisioningStateFromInt(int state)
    {
        return state switch
        {
            0 => ProvisioningState.PRE,
            1 => ProvisioningState.IN,
            2 => ProvisioningState.POST,
            _ => ProvisioningState.UNKNOWN
        };
    }
}