375 lines
12 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using System.Security.Cryptography;
using System.Text;
using System.Text.Json;
using AmtScanner.Api.Data;
using AmtScanner.Api.Models;
using Microsoft.AspNetCore.Mvc;
using Microsoft.EntityFrameworkCore;
namespace AmtScanner.Api.Controllers;
[ApiController]
[Route("api/agent")]
public class AgentController : ControllerBase
{
    /// <summary>A device counts as online if its last report is newer than now minus this window.</summary>
    private static readonly TimeSpan OnlineWindow = TimeSpan.FromMinutes(3);

    /// <summary>Cached screenshots older than this are treated as expired (not served, and pruned on upload).</summary>
    private static readonly TimeSpan ScreenshotTtl = TimeSpan.FromMinutes(1);

    /// <summary>Upper bound applied to the <c>pageSize</c> query parameter of the device listing.</summary>
    private const int MaxPageSize = 200;

    // In-memory screenshot cache keyed by device UUID (Redis could be used in production).
    // Every access must hold _cacheLock.
    private static readonly Dictionary<string, ScreenshotCache> _screenshotCache = new();
    private static readonly object _cacheLock = new();

    private readonly AppDbContext _db;
    private readonly ILogger<AgentController> _logger;
    private readonly IConfiguration _configuration;

    public AgentController(AppDbContext db, ILogger<AgentController> logger, IConfiguration configuration)
    {
        _db = db;
        _logger = logger;
        _configuration = configuration;
    }

    /// <summary>
    /// Receives a device-information report from an agent and upserts the device record by UUID.
    /// </summary>
    /// <remarks>
    /// When <c>Agent:Key</c> is configured, the request must carry a matching <c>X-Agent-Key</c> header.
    /// Returns 401 on a key mismatch, 400 on an empty UUID, and 500 if saving fails.
    /// </remarks>
    [HttpPost("report")]
    public async Task<IActionResult> Report([FromBody] AgentReportDto report)
    {
        var agentKey = Request.Headers["X-Agent-Key"].FirstOrDefault();
        if (!IsAgentKeyValid(agentKey))
        {
            // Log the caller's address rather than the submitted key, which may be a near-miss of the real secret.
            _logger.LogWarning("Agent Key 验证失败: {RemoteIp}", HttpContext.Connection.RemoteIpAddress);
            return Unauthorized(ApiResponse<object>.Fail(401, "Agent Key 无效"));
        }

        if (string.IsNullOrEmpty(report.Uuid))
        {
            return BadRequest(ApiResponse<object>.Fail(400, "UUID 不能为空"));
        }

        _logger.LogInformation("收到设备上报: UUID={Uuid}, IP={Ip}, Hostname={Hostname}",
            report.Uuid, report.IpAddress, report.Hostname);

        try
        {
            // Find the existing record or create a new one keyed by UUID.
            var device = await _db.AgentDevices_new.FindAsync(report.Uuid);
            if (device == null)
            {
                device = new AgentDevice
                {
                    Uuid = report.Uuid,
                    CreatedAt = DateTime.UtcNow
                };
                _db.AgentDevices_new.Add(device);
            }

            // Overwrite every reported field with the latest values.
            device.Hostname = report.Hostname;
            device.IpAddress = report.IpAddress;
            device.MacAddress = report.MacAddress;
            device.SubnetMask = report.SubnetMask;
            device.Gateway = report.Gateway;
            device.OsName = report.OsName;
            device.OsVersion = report.OsVersion;
            device.OsArchitecture = report.OsArchitecture;
            device.CpuName = report.CpuName;
            device.TotalMemoryMB = report.TotalMemoryMB;
            device.Manufacturer = report.Manufacturer;
            device.Model = report.Model;
            device.SerialNumber = report.SerialNumber;
            device.CurrentUser = report.CurrentUser;
            device.UserDomain = report.UserDomain;
            device.BootTime = report.BootTime;
            device.LastReportAt = DateTime.UtcNow;
            device.IsOnline = true;

            await _db.SaveChangesAsync();
            return Ok(ApiResponse<object>.Success(null, "上报成功"));
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "保存设备信息失败");
            return StatusCode(500, ApiResponse<object>.Fail(500, "服务器内部错误"));
        }
    }

    /// <summary>
    /// Receives a heartbeat. It refreshes <c>LastReportAt</c>/<c>IsOnline</c> for a known device.
    /// Unknown UUIDs are acknowledged without creating a record.
    /// </summary>
    // NOTE(review): unlike Report, this endpoint does not check X-Agent-Key; confirm whether agents send it here.
    [HttpPost("heartbeat")]
    public async Task<IActionResult> Heartbeat([FromBody] HeartbeatDto heartbeat)
    {
        if (string.IsNullOrEmpty(heartbeat.Uuid))
        {
            return BadRequest(ApiResponse<object>.Fail(400, "UUID 不能为空"));
        }

        var device = await _db.AgentDevices_new.FindAsync(heartbeat.Uuid);
        if (device != null)
        {
            device.LastReportAt = DateTime.UtcNow;
            device.IsOnline = true;
            await _db.SaveChangesAsync();
        }

        return Ok(ApiResponse<object>.Success(null));
    }

    /// <summary>
    /// Lists agent devices, newest report first.
    /// Results can be filtered by a substring of UUID, hostname or IP.
    /// </summary>
    /// <param name="page">1-based page number; values below 1 are treated as 1.</param>
    /// <param name="pageSize">Items per page, clamped to [1, <see cref="MaxPageSize"/>].</param>
    /// <param name="search">Optional substring filter.</param>
    [HttpGet("devices")]
    public async Task<IActionResult> GetDevices([FromQuery] int page = 1, [FromQuery] int pageSize = 20, [FromQuery] string? search = null)
    {
        // A non-positive page or pageSize would otherwise pass a negative count to Skip/Take.
        page = Math.Max(page, 1);
        pageSize = Math.Clamp(pageSize, 1, MaxPageSize);
        // Compute in long so a huge page number cannot overflow into a negative offset.
        var skip = (int)Math.Min((long)(page - 1) * pageSize, int.MaxValue);

        // Read-only listing: IsOnline is recomputed below for the response but never saved.
        var query = _db.AgentDevices_new.AsNoTracking();
        if (!string.IsNullOrEmpty(search))
        {
            query = query.Where(d =>
                d.Uuid.Contains(search) ||
                d.Hostname.Contains(search) ||
                d.IpAddress.Contains(search));
        }

        var total = await query.CountAsync();
        var items = await query
            .OrderByDescending(d => d.LastReportAt)
            .Skip(skip)
            .Take(pageSize)
            .ToListAsync();

        // Derive online status: no report within OnlineWindow means offline.
        var threshold = DateTime.UtcNow - OnlineWindow;
        foreach (var item in items)
        {
            item.IsOnline = item.LastReportAt > threshold;
        }

        return Ok(ApiResponse<object>.Success(new { items, total, page, pageSize }));
    }

    /// <summary>
    /// Returns a single device.
    /// Its <c>IsOnline</c> flag is derived from the last report time.
    /// </summary>
    [HttpGet("devices/{uuid}")]
    public async Task<IActionResult> GetDevice(string uuid)
    {
        var device = await _db.AgentDevices_new.FindAsync(uuid);
        if (device == null)
        {
            return NotFound(ApiResponse<object>.Fail(404, "设备不存在"));
        }

        device.IsOnline = device.LastReportAt > DateTime.UtcNow - OnlineWindow;
        return Ok(ApiResponse<object>.Success(device));
    }

    /// <summary>
    /// Deletes a device record and evicts its cached screenshot.
    /// </summary>
    [HttpDelete("devices/{uuid}")]
    public async Task<IActionResult> DeleteDevice(string uuid)
    {
        var device = await _db.AgentDevices_new.FindAsync(uuid);
        if (device == null)
        {
            return NotFound(ApiResponse<object>.Fail(404, "设备不存在"));
        }

        _db.AgentDevices_new.Remove(device);
        await _db.SaveChangesAsync();

        // Without this, a deleted device's last screenshot would stay retrievable until it expired.
        lock (_cacheLock)
        {
            _screenshotCache.Remove(device.Uuid);
        }

        return Ok(ApiResponse<object>.Success(null, "删除成功"));
    }

    /// <summary>
    /// Receives a screenshot and caches it in memory under the device UUID.
    /// The request may be a form upload (field <c>uuid</c>, file <c>screenshot</c>) or a JSON
    /// <see cref="ScreenshotDto"/> whose <c>Screenshot</c> property is Base64-encoded.
    /// </summary>
    /// <remarks>
    /// Returns 400 for missing or malformed input and 415 for other content types.
    /// </remarks>
    // NOTE(review): this endpoint does not check X-Agent-Key, so any client can overwrite a device's screenshot.
    [HttpPost("screenshot")]
    [RequestSizeLimit(10 * 1024 * 1024)] // 10MB
    public async Task<IActionResult> UploadScreenshot()
    {
        try
        {
            string uuid;
            byte[] screenshotData;

            if (Request.HasFormContentType)
            {
                var form = await Request.ReadFormAsync();
                uuid = form["uuid"].ToString();
                var file = form.Files["screenshot"];
                if (file == null || file.Length == 0)
                {
                    return BadRequest(ApiResponse<object>.Fail(400, "截图文件为空"));
                }

                using var ms = new MemoryStream();
                await file.CopyToAsync(ms);
                screenshotData = ms.ToArray();
            }
            else if (Request.HasJsonContentType())
            {
                var json = await Request.ReadFromJsonAsync<ScreenshotDto>();
                if (json == null || string.IsNullOrEmpty(json.Uuid) || string.IsNullOrEmpty(json.Screenshot))
                {
                    return BadRequest(ApiResponse<object>.Fail(400, "参数无效"));
                }

                uuid = json.Uuid;
                screenshotData = Convert.FromBase64String(json.Screenshot);
            }
            else
            {
                return StatusCode(415, ApiResponse<object>.Fail(415, "不支持的 Content-Type"));
            }

            if (string.IsNullOrEmpty(uuid))
            {
                return BadRequest(ApiResponse<object>.Fail(400, "UUID 不能为空"));
            }

            var now = DateTime.UtcNow;
            lock (_cacheLock)
            {
                _screenshotCache[uuid] = new ScreenshotCache
                {
                    Data = screenshotData,
                    UpdatedAt = now
                };
                PruneExpiredScreenshots(now);
            }

            _logger.LogDebug("收到截图: UUID={Uuid}, Size={Size}KB", uuid, screenshotData.Length / 1024);
            return Ok(ApiResponse<object>.Success(null));
        }
        catch (Exception ex) when (ex is JsonException or FormatException or InvalidDataException)
        {
            // Malformed client input is a client error, not a server fault.
            // Examples: invalid JSON, invalid Base64, or an unreadable multipart body.
            _logger.LogWarning(ex, "截图数据格式无效");
            return BadRequest(ApiResponse<object>.Fail(400, "参数无效"));
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "处理截图上传失败");
            return StatusCode(500, ApiResponse<object>.Fail(500, "服务器内部错误"));
        }
    }

    /// <summary>
    /// Returns the cached screenshot of a device as <c>image/jpeg</c>.
    /// Returns 404 if there is none or it is older than <see cref="ScreenshotTtl"/>.
    /// </summary>
    [HttpGet("screenshot/{uuid}")]
    public IActionResult GetScreenshot(string uuid)
    {
        bool found;
        byte[] data;
        lock (_cacheLock)
        {
            found = TryGetFreshScreenshot(uuid, DateTime.UtcNow, out data);
        }

        if (found)
        {
            return File(data, "image/jpeg");
        }

        return NotFound(ApiResponse<object>.Fail(404, "截图不存在或已过期"));
    }

    /// <summary>
    /// Returns the cached screenshots of several devices as Base64 strings.
    /// Devices without a fresh screenshot map to <c>null</c>; null entries in the request are ignored.
    /// </summary>
    [HttpPost("screenshots/batch")]
    public IActionResult GetScreenshotsBatch([FromBody] List<string> uuids)
    {
        if (uuids is null)
        {
            return BadRequest(ApiResponse<object>.Fail(400, "参数无效"));
        }

        var now = DateTime.UtcNow;
        var snapshots = new Dictionary<string, byte[]?>();
        lock (_cacheLock)
        {
            foreach (var uuid in uuids)
            {
                if (uuid is null)
                {
                    continue;
                }

                snapshots[uuid] = TryGetFreshScreenshot(uuid, now, out var data) ? data : null;
            }
        }

        // Encode outside the lock. Cached arrays are only ever replaced, never mutated, so this is safe.
        var result = snapshots.ToDictionary(
            kv => kv.Key,
            kv => kv.Value is null ? null : Convert.ToBase64String(kv.Value));

        return Ok(ApiResponse<object>.Success(result));
    }

    /// <summary>
    /// Lists online devices (reported within <see cref="OnlineWindow"/>).
    /// Each entry says whether a fresh screenshot is cached and, if so, its URL.
    /// </summary>
    [HttpGet("screenshots")]
    public async Task<IActionResult> GetAllScreenshots()
    {
        var now = DateTime.UtcNow;
        var threshold = now - OnlineWindow;
        var onlineDevices = await _db.AgentDevices_new
            .Where(d => d.LastReportAt > threshold)
            .Select(d => new { d.Uuid, d.Hostname, d.IpAddress })
            .ToListAsync();

        var result = new List<object>(onlineDevices.Count);
        lock (_cacheLock)
        {
            foreach (var device in onlineDevices)
            {
                var hasScreenshot = TryGetFreshScreenshot(device.Uuid, now, out _);
                result.Add(new
                {
                    device.Uuid,
                    device.Hostname,
                    device.IpAddress,
                    HasScreenshot = hasScreenshot,
                    // The UUID is agent-supplied, so escape it before embedding it in a path segment.
                    ScreenshotUrl = hasScreenshot ? $"/api/agent/screenshot/{Uri.EscapeDataString(device.Uuid)}" : null
                });
            }
        }

        return Ok(ApiResponse<object>.Success(result));
    }

    /// <summary>
    /// Checks the supplied agent key against the configured <c>Agent:Key</c>.
    /// Always returns <c>true</c> when no key is configured (authentication disabled).
    /// </summary>
    private bool IsAgentKeyValid(string? providedKey)
    {
        var expectedKey = _configuration["Agent:Key"];
        if (string.IsNullOrEmpty(expectedKey))
        {
            return true;
        }

        if (providedKey is null)
        {
            return false;
        }

        // Constant-time comparison, so response timing does not reveal how much of the key matched.
        return CryptographicOperations.FixedTimeEquals(
            Encoding.UTF8.GetBytes(providedKey),
            Encoding.UTF8.GetBytes(expectedKey));
    }

    /// <summary>
    /// Looks up a screenshot that is no older than <see cref="ScreenshotTtl"/> relative to <paramref name="now"/>.
    /// The caller must hold <c>_cacheLock</c>.
    /// </summary>
    private static bool TryGetFreshScreenshot(string uuid, DateTime now, out byte[] data)
    {
        if (_screenshotCache.TryGetValue(uuid, out var entry) && entry.UpdatedAt >= now - ScreenshotTtl)
        {
            data = entry.Data;
            return true;
        }

        data = Array.Empty<byte>();
        return false;
    }

    /// <summary>
    /// Removes every cached screenshot older than <see cref="ScreenshotTtl"/>.
    /// The caller must hold <c>_cacheLock</c>.
    /// </summary>
    private static void PruneExpiredScreenshots(DateTime now)
    {
        var cutoff = now - ScreenshotTtl;
        var expiredKeys = _screenshotCache
            .Where(kv => kv.Value.UpdatedAt < cutoff)
            .Select(kv => kv.Key)
            .ToList();

        foreach (var key in expiredKeys)
        {
            _screenshotCache.Remove(key);
        }
    }
}
/// <summary>
/// Device-information payload posted by an agent to <c>POST api/agent/report</c>.
/// Every string member starts out as an empty string, so omitted fields are never null by default.
/// </summary>
public class AgentReportDto
{
    /// <summary>Device UUID; used as the device record key.</summary>
    public string Uuid { get; set; } = string.Empty;

    /// <summary>Host name reported by the agent.</summary>
    public string Hostname { get; set; } = string.Empty;

    /// <summary>IP address reported by the agent.</summary>
    public string IpAddress { get; set; } = string.Empty;

    /// <summary>MAC address reported by the agent.</summary>
    public string MacAddress { get; set; } = string.Empty;

    /// <summary>Subnet mask reported by the agent.</summary>
    public string SubnetMask { get; set; } = string.Empty;

    /// <summary>Default gateway reported by the agent.</summary>
    public string Gateway { get; set; } = string.Empty;

    /// <summary>Operating system name.</summary>
    public string OsName { get; set; } = string.Empty;

    /// <summary>Operating system version.</summary>
    public string OsVersion { get; set; } = string.Empty;

    /// <summary>Operating system architecture.</summary>
    public string OsArchitecture { get; set; } = string.Empty;

    /// <summary>CPU model name.</summary>
    public string CpuName { get; set; } = string.Empty;

    /// <summary>Total memory as reported by the agent (megabytes, per the property name).</summary>
    public long TotalMemoryMB { get; set; }

    /// <summary>Hardware manufacturer.</summary>
    public string Manufacturer { get; set; } = string.Empty;

    /// <summary>Hardware model.</summary>
    public string Model { get; set; } = string.Empty;

    /// <summary>Hardware serial number.</summary>
    public string SerialNumber { get; set; } = string.Empty;

    /// <summary>Currently logged-in user.</summary>
    public string CurrentUser { get; set; } = string.Empty;

    /// <summary>Domain of the current user.</summary>
    public string UserDomain { get; set; } = string.Empty;

    /// <summary>Last boot time, or <c>null</c> when the agent does not supply it.</summary>
    public DateTime? BootTime { get; set; }
}
/// <summary>
/// Payload posted by an agent to <c>POST api/agent/heartbeat</c>.
/// </summary>
public class HeartbeatDto
{
    /// <summary>UUID of the device sending the heartbeat; defaults to an empty string.</summary>
    public string Uuid { get; set; } = string.Empty;
}
/// <summary>
/// JSON form of a screenshot upload to <c>POST api/agent/screenshot</c>.
/// </summary>
public class ScreenshotDto
{
    /// <summary>UUID of the device the screenshot belongs to; defaults to an empty string.</summary>
    public string Uuid { get; set; } = string.Empty;

    /// <summary>Screenshot bytes encoded as Base64; defaults to an empty string.</summary>
    public string Screenshot { get; set; } = string.Empty;
}
/// <summary>
/// In-memory cache entry holding the most recently uploaded screenshot of a device.
/// </summary>
public class ScreenshotCache
{
    /// <summary>Raw image bytes; an empty array until assigned.</summary>
    public byte[] Data { get; set; } = Array.Empty<byte>();

    /// <summary>Time the entry was stored (the controller assigns a UTC timestamp).</summary>
    public DateTime UpdatedAt { get; set; } = default;
}