383 lines
11 KiB
C++

//----------------------------------------------------------------------------
//
// Copyright (C) Intel Corporation, 2006 - 2007.
//
// File: SocksSupplier.h
//
// Contents:
//
// Notes:
//----------------------------------------------------------------------------
//===================================================
// INCLUDES
//===================================================
#include <ace/Message_Block.h>
#include <ace/SOCK_Stream.h>
#include <ace/INET_Addr.h>
#include "TcpSvcHandler.h"
#include "TcpSupplier.h"
#include "SOCKSv5.h"
#include "TcpConsumer.h"
#include "ChannelConsumer.h"
#include "SocksSvcHandler.h"
#include "global.h"
#include "OptionsUtils.h"
//===================================================
// SocksSupplier Implementation
//===================================================
//-----------------------------------------
// Input handler: runs the next SOCKS negotiation
// step while negotiation is still needed, otherwise
// forwards tunnel data. The handle 'h' is not used.
//
// Returns STATUS_CONNECTION_CLOSED when the state is
// CONNECTION_FAILURE; STATUS_NETWORK_ERROR or
// STATUS_FAILURE are passed through; any other
// outcome is reported as STATUS_SUCCESS.
//-----------------------------------------
STATUS SocksSupplier::handle_input(ACE_HANDLE h)
{
    STATUS rep = STATUS_SUCCESS;

    // A connection that already failed is not processed any further.
    if (socks_state() == Socks_Svc_Handler::CONNECTION_FAILURE)
    {
        PRINT_ERROR_STATUS(rep, ACE_TEXT("SocksSupplier::handle_input - state is CONNECTION_FAILURE\n"));
        return STATUS_CONNECTION_CLOSED;
    }

    if (!isNegotiationNeeded())
    {
        // Handshake is over: plain data forwarding.
        rep = tunnelData();
        if (rep != STATUS_SUCCESS)
        {
            ACE_DEBUG((MY_DEBUG
                       ACE_TEXT("Failed during SOCKS tunnel data. rep = %d\n"),
                       rep));
        }
    }
    else
    {
        // Still inside the SOCKS handshake.
        rep = startNegotiate();
        PRINT_ERROR_STATUS(rep, ACE_TEXT("Failed during SOCKS negotiation\n"));
    }

    // Only these two statuses are propagated to the caller.
    if ((rep == STATUS_NETWORK_ERROR) || (rep == STATUS_FAILURE))
    {
        return rep;
    }
    return STATUS_SUCCESS;
}
//-----------------------------------------
// handle SOCKS negotiation
//-----------------------------------------
STATUS SocksSupplier::startNegotiate(void)
{
    ACE_TRACE(ACE_TEXT("SocksSupplier::startNegotiate"));

    // Run the handler that matches the current handshake state.
    STATUS outcome = STATUS_SUCCESS;
    const Socks_Svc_Handler::SOCKS_STATE phase = socks_state();
    if (phase == Socks_Svc_Handler::INIT)
    {
        outcome = startConnection();
    }
    else if (phase == Socks_Svc_Handler::AUTHENTICATION)
    {
        outcome = authenticate();
    }
    else if (phase == Socks_Svc_Handler::NEGOTIATION)
    {
        outcome = negotiate();
    }
    else
    {
        ACE_DEBUG((MY_DEBUG ACE_TEXT("Negotiation in illegal state\n")));
        outcome = STATUS_FAILURE;
    }

    // Statuses after which the TCP consumer has to be closed from here
    // (otherwise the tunnel will close it).
    const bool negotiationFailed =
        (outcome == STATUS_NETWORK_ERROR) ||
        (outcome == STATUS_FAILURE) ||
        (outcome == STATUS_BAD_REQUEST) ||
        (outcome == STATUS_AUTH_FAILURE) ||
        (outcome == STATUS_OPERATION_NOT_SUPPORTED) ||
        (outcome == STATUS_LOCK_FAILURE) ||
        (outcome == STATUS_QUEUE_TIMEOUT);

    if ((consumer() != NULL) && negotiationFailed)
    {
        consumer()->connectionClose();
    }
    return outcome;
}
//-----------------------------------------
// Returns true if negotiation is needed before
// starting forwarding data
//
//-----------------------------------------
bool SocksSupplier::isNegotiationNeeded()
{
return (socks_state() < Socks_Svc_Handler::WAIT_NEGOTIATION_REP);
}
//-----------------------------------------
// handle SOCKS first message (client version /
// method list) and reply with the selected method
//-----------------------------------------
STATUS SocksSupplier::startConnection()
{
    ACE_TRACE(ACE_TEXT("SocksSupplier::startConnection"));
    SOCKSv5_ClientVersion clientVersion;

    // Fetch the configured authentication settings; all three must be available.
    const bool* authFlag = getSocksNeedAuthentication();
    const ACE_TString* dllName = getSocksDllName();
    const ACE_TString* dllParams = getSocksDllParameters();
    if (!authFlag || !dllName || !dllParams)
    {
        ACE_DEBUG((MY_DEBUG
                   ACE_TEXT("Failed to get SOCKS Authentication details\n")));
        return STATUS_FAILURE;
    }
    Authentication_Param authConfig(*authFlag, *dllName, *dllParams);

    // Method offered back to the client: user/password when authentication
    // is configured, otherwise no authentication.
    SOCKS_V5_METHOD chosenMethod = authConfig._authenticate ? USER_PASS : NO_AUTH;

    if (clientVersion.read(getPeer()) != STATUS_SUCCESS)
    {
        ACE_DEBUG((MY_WARNING ACE_TEXT("[%s] Failed to read SOCKS version message\n"),identifier()));
        return STATUS_NETWORK_ERROR;
    }

    // The client did not list the required method: it is still answered,
    // but with NOT_ACCEPTABLE, and the call reports "not supported".
    STATUS methodStatus = STATUS_SUCCESS;
    if (!MPSSupportSocksMethod(clientVersion))
    {
        chosenMethod = NOT_ACCEPTABLE;
        ACE_DEBUG((MY_WARNING ACE_TEXT("[%s] Received unsupported SOCKS connection message\n"),identifier()));
        methodStatus = STATUS_OPERATION_NOT_SUPPORTED;
    }

    STATUS sendStatus = STATUS_SUCCESS;
    if (consumer() != NULL)
    {
        sendStatus = consumer()->sendVersionMsg(chosenMethod);
        if (sendStatus != STATUS_SUCCESS)
        {
            ACE_DEBUG((MY_WARNING ACE_TEXT("[%s] Failed to send SOCKS version message\n"),identifier()));
        }
    }

    // An unsupported-method status takes precedence over the send result.
    if (methodStatus != STATUS_SUCCESS)
    {
        return methodStatus;
    }
    return sendStatus;
}
//-----------------------------------------
// handle SOCKS client request message:
// validate it (CONNECT command, domain name
// address, non-empty host) and ask the channel
// consumer to open a channel to the target.
//
// Returns STATUS_NETWORK_ERROR when the request
// cannot be read, STATUS_BAD_REQUEST when it is
// rejected, STATUS_MALLOC_FAILURE when the channel
// consumer cannot be allocated, the openChannel()
// status when that fails, otherwise STATUS_SUCCESS.
//-----------------------------------------
STATUS SocksSupplier::negotiate(void)
{
    ACE_TRACE(ACE_TEXT("SocksSupplier::negotiate"));
    SOCKSv5_ClientRequest c_request;

    //read message
    if (c_request.read(getPeer()) != STATUS_SUCCESS)
    {
        ACE_ERROR_RETURN((MY_ERROR
                          "[%s] %p\n",identifier(),
                          "Failed reading socks negotiation message"),
                         STATUS_NETWORK_ERROR);
    }

    // Only the CONNECT command is supported
    if (c_request.CMD != CONNECT)
    {
        ACE_DEBUG((MY_ERROR ACE_TEXT("[%s] Received SOCKS negotiation request with invalid command\n"),identifier()));
        if (consumer() != NULL)
        {
            ((SocksConsumer*)consumer())->openRep(Tcp_Consumer::FAILURE, SocksConsumer::COMMAND_NOT_SUPPORTED);
        }
        return STATUS_BAD_REQUEST;
    }

    // We support only domain name address type
    if (c_request.ATYP != DOMAINNAME)
    {
        // Address not supported
        ACE_DEBUG((MY_ERROR
                   ACE_TEXT("[%s] Received SOCKS negotiation request with address type NOT SUPPOTRED\n"),identifier()));
        if (consumer() != NULL)
        {
            ((SocksConsumer*)consumer())->openRep(Tcp_Consumer::FAILURE, SocksConsumer::ADDRESS_TYPE_NOT_SUPPORTED);
        }
        return STATUS_BAD_REQUEST;
    }

    ACE_CString host;
    host = c_request.DST_ADDR;
    if (host.empty())
    {
        ACE_DEBUG((MY_ERROR ACE_TEXT("[%s] Received ILLEGAL SOCKS negotiation request - no host\n"),identifier()));
        if (consumer() != NULL)
        {
            ((SocksConsumer*)consumer())->openRep(Tcp_Consumer::FAILURE, SocksConsumer::SERVER_FAILURE);
        }
        return STATUS_BAD_REQUEST;
    }

    ACE_DEBUG((MY_DEBUG
               ACE_TEXT ("SOCKS negotiation\n\thost = %s\n\tport = %d\n"),
               host.c_str(),
               c_request.DST_PORT));

    //get sender address:
    ACE_INET_Addr sender_addr;
    if (getPeer().get_remote_addr(sender_addr) != 0)
    {
        // Best effort: keep going with the default-constructed address,
        // but leave a trace of the failure.
        ACE_DEBUG((MY_WARNING ACE_TEXT("[%s] Failed to get SOCKS client remote address\n"),identifier()));
    }

    // Zero-initialised so that openChannel() always receives a terminated
    // string, even if get_host_addr() fails and leaves the buffer untouched.
    char hoststr[MAX_HOST_NAME_LEN+1] = {0};
    sender_addr.get_host_addr (hoststr, MAX_HOST_NAME_LEN+1);

    ACE_NEW_RETURN(_channel_consummer, Channel_Consumer(this), STATUS_MALLOC_FAILURE);

    // The state is switched before the open request is issued.
    ((Socks_Svc_Handler*)_svc_handler)->_socks_state = Socks_Svc_Handler::WAIT_NEGOTIATION_REP;
    STATUS rep = _channel_consummer->openChannel(host, c_request.DST_PORT, hoststr, sender_addr.get_port_number(), consumer());
    if (rep != STATUS_SUCCESS)
    {
        ACE_DEBUG((MY_DEBUG ACE_TEXT("[%s] Failed to send open channel request! reason = %d\n"),identifier(),rep));
        if (consumer() != NULL)
        {
            (((SocksConsumer*)consumer())->openRep(Tcp_Consumer::FAILURE, SocksConsumer::HOST_UNREACHABLE));
        }
        ((Socks_Svc_Handler*)_svc_handler)->_socks_state = Socks_Svc_Handler::CONNECTION_FAILURE;
        return rep;
    }
    return STATUS_SUCCESS;
}
//-----------------------------------------
// Handle SOCKS Authentication.
//
// NOTE:
// Enter here only when MPS configured to support SOCKS authentication.
//-----------------------------------------
STATUS SocksSupplier::authenticate()
{
    ACE_TRACE(ACE_TEXT("SocksSupplier::authenticate"));
    SOCKSv5_UssrPwdSubNegotiate c_request;
    STATUS rep = STATUS_SUCCESS;

    //read message
    if (c_request.read(getPeer()) != STATUS_SUCCESS)
    {
        ACE_ERROR_RETURN((MY_ERROR
                          "[%s] %p\n",identifier(),
                          "Failed reading socks authentication message"),
                         STATUS_NETWORK_ERROR);
    }

    //verify user name/pwd:
    // 'status' stays SOCKS_SUCCEEDED unless the check below rejects the
    // credentials. NOTE(review): outside ACE_WIN32 builds no check is
    // compiled in, so every user name/password is accepted - confirm
    // this is intended.
    SOCKS_V5_REP status = SOCKS_SUCCEEDED;
    // Not currently supported in Linux
#ifdef ACE_WIN32
    // Zero-initialised so the failure log below never prints indeterminate
    // memory when the authentication function does not fill the buffer.
    char error[MAX_DLL_ERR_LEN] = {0};
    const bool* needAuthenticationPtr = getSocksNeedAuthentication();
    const ACE_TString* socksDllNamePtr = getSocksDllName();
    const ACE_TString* socksDllParametersPtr = getSocksDllParameters();
    if (needAuthenticationPtr == NULL || socksDllNamePtr == NULL || socksDllParametersPtr == NULL) {
        ACE_DEBUG((MY_DEBUG
                   ACE_TEXT("Failed to get SOCKS Authentication details\n")));
        return STATUS_FAILURE;
    }
    Authentication_Param socks_auth_params(*needAuthenticationPtr,
                                           *socksDllNamePtr,
                                           *socksDllParametersPtr);
    // NOTE(review): when authentication is configured but _socks_auth_func
    // is NULL the credentials are not checked and the reply is success.
    if ((_socks_auth_func != NULL) && (socks_auth_params._authenticate))
    {
        UINT8 authorized = _socks_auth_func (c_request.UNAME.c_str(),
                                             c_request.PWD.c_str(),
                                             socks_auth_params._dllParams.c_str(),
                                             &(error[0]),
                                             MAX_DLL_ERR_LEN );
        // Guarantee termination whatever the external function wrote.
        error[MAX_DLL_ERR_LEN - 1] = '\0';
        if (authorized == 0)
        {
            ACE_DEBUG((MY_WARNING ACE_TEXT("[%s] SOCKS authentication failed error = %s. \n"),identifier(), error));
            status = CONNECTION_NOT_ALLOWED;
            rep = STATUS_AUTH_FAILURE;
        }
        else
        {
            ACE_DEBUG((MY_DEBUG ACE_TEXT("authentication SUCCESS\n")));
        }
    }
#endif /* ACE_WIN32 */

    //send auth reply:
    if (consumer() != NULL)
    {
        if (((SocksConsumer*)(consumer()))->sendAuthMsg(status) != STATUS_SUCCESS)
        {
            ACE_DEBUG((MY_ERROR ACE_TEXT("[%s] Failed to send SOCKS authentication message\n"),identifier()));
            // A failed send replaces any earlier status (including
            // STATUS_AUTH_FAILURE) in the returned value.
            rep = STATUS_FATAL_ERROR;
        }
    }
    return rep;
}
//-----------------------------------------
// Returns the SOCKS state stored in the
// service handler (_svc_handler).
//-----------------------------------------
Socks_Svc_Handler::SOCKS_STATE SocksSupplier::socks_state()
{
    Socks_Svc_Handler* const socksHandler = (Socks_Svc_Handler*)_svc_handler;
    return socksHandler->_socks_state;
}
//-----------------------------------------
// returns true if MPS supports one of the methods client asked
// in 'version' message. (for more details see rfc 1929)
//
// NOTE:
// this depends on whether MPS supports client Authentication
// (defined in config file)
//-----------------------------------------
bool SocksSupplier::MPSSupportSocksMethod(SOCKSv5_ClientVersion& version)
{
    const bool* needAuthenticationPtr = getSocksNeedAuthentication();
    const ACE_TString* socksDllNamePtr = getSocksDllName();
    const ACE_TString* socksDllParametersPtr = getSocksDllParameters();
    if (needAuthenticationPtr == NULL || socksDllNamePtr == NULL || socksDllParametersPtr == NULL) {
        ACE_DEBUG((MY_DEBUG
                   ACE_TEXT("Failed to get SOCKS Authentication details\n")));
        // Without the configuration no method can be confirmed. This must be
        // 'false': returning a STATUS code from a bool function would turn
        // any nonzero status into 'true' ("supported").
        return false;
    }
    Authentication_Param socks_auth_params(*needAuthenticationPtr,
                                           *socksDllNamePtr,
                                           *socksDllParametersPtr);

    // With authentication configured the client must offer user/password
    // (if not, the connection needs to be closed); otherwise it must offer
    // the no-authentication method.
    const SOCKS_V5_METHOD requiredMethod =
        socks_auth_params._authenticate ? USER_PASS : NO_AUTH;
    return ClientSupportMethod(version, requiredMethod);
}
//-----------------------------------------
// returns true if the client's SOCKS 'version' message
// lists 'method' among its supported methods
//-----------------------------------------
bool SocksSupplier::ClientSupportMethod(SOCKSv5_ClientVersion& version, SOCKS_V5_METHOD method)
{
    // Scan the first NMETHODS entries of METHODS and stop at the first match.
    bool offered = false;
    for (int idx = 0; !offered && (idx < version.NMETHODS); ++idx)
    {
        offered = (version.METHODS[idx] == method);
    }
    return offered;
}