//////////////////////////////////////////////////////////////////////////////
//
//                      INTEL CONFIDENTIAL
//       Copyright 2017 Intel Corporation All Rights Reserved.
//
// The source code contained or described herein and all documents related to
// the source code ("Material") are owned by Intel Corporation or its
// suppliers. Title to the Material remains with Intel Corporation, its
// suppliers, or licensors. The Material contains trade secrets and
// proprietary and confidential information of Intel Corporation, its
// suppliers, and licensors, and is protected by worldwide copyright and trade
// secret laws and treaty provisions. No part of the Material may be used,
// copied, reproduced, modified, published, uploaded, posted, transmitted,
// distributed, or disclosed in any way without Intel's prior express written
// permission.
//
// No license under any patent, copyright, trade secret or other intellectual
// property right is granted to or conferred upon you by disclosure or
// delivery of the Materials, either expressly, by implication, inducement,
// estoppel or otherwise. Any license under such intellectual property rights
// must be express and approved by Intel in writing.
//
// Unless otherwise agreed by Intel in writing, you may not remove or alter
// this notice or any other notice embedded in Materials by Intel or Intel's
// suppliers or licensors in any way.
//
//////////////////////////////////////////////////////////////////////////////
#include "TransportTLS.hpp"


TransportTLS::TransportTLS(const char* address, const char* port, const char* targetChainFile,
	std::shared_ptr<SSLMethod> sslMethod, std::shared_ptr<OSDependencies> deps) {
	// Where to connect and which certificate chain to verify against;
	// Connect() reads these when it opens the TLS session.
	// NOTE(review): if these members are raw `const char*`, the caller must keep
	// the strings alive for the lifetime of this object — confirm in the header.
	Address = address;
	Port = port;
	TargetChainFile = targetChainFile;

	// Collaborators: the TLS implementation and the OS socket helpers.
	SslMethod = std::move(sslMethod);
	Dependencies = std::move(deps);

	// Nothing has been received yet, so nothing is buffered in the TLS layer.
	MoreData = false;
}

// Open the TLS connection to Address:Port and tune the underlying socket.
//
// Returns:
//   Connection_Error::Bad_Argument             if no SSL method or no OS-dependency
//                                              object was supplied at construction;
//   Connection_Error::Could_Not_Connect_To_BMC if SslMethod->Open() reports an error,
//                                              the resulting socket is invalid, or the
//                                              keep-alive / no-delay options cannot be set;
//   Connection_Error::No_Error                 otherwise.
Connection_Error TransportTLS::Connect() {
	using namespace std::placeholders; // for _1, _2

	// Both collaborators are required. Check them up front so we never open a
	// connection that we would then fail to configure (or crash configuring).
	if (SslMethod == nullptr || Dependencies == nullptr) {
		return Connection_Error::Bad_Argument;
	}

	// Route SSL-layer diagnostics into this transport's message list.
	// NOTE(review): the handler captures a raw `this`; confirm SslMethod never
	// invokes it after this TransportTLS is destroyed.
	SslMethod->RegisterErrorHandler(std::bind(&TransportTLS::SSLErrorHandler, this, _1, _2));

	std::string addr(Address);
	std::string port(Port);
	std::string targetChainFile(TargetChainFile);
	auto error = SslMethod->Open(addr, port, targetChainFile);
	if (error) {
		return Connection_Error::Could_Not_Connect_To_BMC;
	}

	// Validate the raw socket and apply keep-alive / no-delay options; the
	// checks short-circuit in order, stopping at the first failure.
	OSDepSocket socket = SslMethod->Getfd();
	if (!Dependencies->IsSocketValid(socket) ||
		!Dependencies->SetSockOptKeepAlive(socket) ||
		!Dependencies->SetSockOptNoDelay(socket)) {
		return Connection_Error::Could_Not_Connect_To_BMC;
	}
	return Connection_Error::No_Error;
}

// Write `length` bytes from `send_buffer` through the TLS layer.
// The TLS layer's Write() result is returned unchanged (as int).
int TransportTLS::Send(char *send_buffer, int length) {
	const int written = SslMethod->Write(send_buffer, length);
	return written;
}

int TransportTLS::Receive(char *buffer, int length) {
	int len = SslMethod->Read(buffer, length);
	this->MoreData = false;
	if (SslMethod->Pending() > 0) 
		this->MoreData = true;
	return len;
}

bool TransportTLS::AwaitData()
{
	bool result = false;
	OSDepSocket socket =  SslMethod->Getfd();
	if (socket == -1) {
		result = false;
	}
	if (this->MoreData) {
		result = true;
	}
	else {
		if (!Dependencies->IsSocketValid(socket)) {
			return false;
		} else {
			result = Dependencies->AwaitData(socket);
			return result;
		}
	}
	return result;
}

// Close the TLS connection. Does nothing if no SSL method was supplied at
// construction.
void TransportTLS::Close() {
	if (!SslMethod) {
		return;
	}
	SslMethod->Close();
}

void TransportTLS::ForEachWarningAndError(
	unsigned int warning_notificaiton_type,
	unsigned int error_notification_type,
	std::function<void(unsigned int, std::string)> log_function) {
	for (auto iter = Messages.begin(); iter != Messages.end(); iter++) {
		// Is message a warning or and error?
		unsigned int notification_type = warning_notificaiton_type;
		if (iter->type <= 128) // TODO: define this
			notification_type = error_notification_type;
		log_function(notification_type, iter->value);
	}
	Messages.clear();
}

// Queue `message` for the next ForEachWarningAndError() call. It is inserted
// at the front, so newer messages are visited before older ones.
void TransportTLS::RegisterMessage(TransportMessage message) {
	Messages.emplace_front(std::move(message));
}

// Record an SSL-layer warning or error so ForEachWarningAndError() can report it.
//
// If the SSL layer supplied no text, a default description is chosen from
// `error`. If there is still no text afterwards (SSL_No_Error, or a value not
// listed below), nothing is recorded.
void TransportTLS::SSLErrorHandler(SSL_Error error, std::string message) {
	if (message.empty()) {
		switch (error) {
		case SSL_Error::Chain_File_Not_Found:
			message = "Aborted due to no server certificate found or invalid server certificate.";
			break;
		case SSL_Error::Unable_To_Verify_Target:
			message = "Non-matching server certificate.";
			break;
		case SSL_Error::Verify_Target_Off:
			message = "Server verification disabled.";
			break;
		case SSL_Error::Unable_to_Contact_Target:
			message = "Unable to Contact Target.";
			break;
		case SSL_Error::Handshake_Failed:
			message = "TLS Layer was not able to contact target.";
			break;
		case SSL_Error::Passphrase_File_Not_Found:
			message = "TLS Layer unable to read passphrase file.";
			break;
		case SSL_Error::SSL_No_Error:
			// No default text: an empty message means nothing is recorded.
			break;
		case SSL_Error::General_Failure:
			message = "A 'General Failure' was reported while creating the TLS connection.";
			break;
		case SSL_Error::General_Warning:
			message = "A 'General Warning' was reported while creating the TLS connection.";
			break;
		}
	}
	if (message.empty()) {
		return;
	}
	TransportMessage entry = { static_cast<unsigned int>(error), std::move(message) };
	RegisterMessage(std::move(entry));
}

// Release the TLS connection on destruction. Close() is a no-op when no SSL
// method was supplied.
TransportTLS::~TransportTLS() {
	Close();
}
