diff --git a/src/os/inc/osapi-error.h b/src/os/inc/osapi-error.h index b94982993..14182f71c 100644 --- a/src/os/inc/osapi-error.h +++ b/src/os/inc/osapi-error.h @@ -87,6 +87,7 @@ typedef char os_err_name_t[OS_ERROR_NAME_LENGTH]; #define OS_ERR_OPERATION_NOT_SUPPORTED (-38) /**< @brief Requested operation not support on supplied object(s) */ #define OS_ERR_INVALID_SIZE (-40) /**< @brief Invalid Size */ #define OS_ERR_OUTPUT_TOO_LARGE (-41) /**< @brief Size of output exceeds limit */ +#define OS_ERR_INVALID_ARGUMENT (-42) /**< @brief Invalid argument value (other than ID or size) */ /* ** Defines for File System Calls diff --git a/src/os/inc/osapi-sockets.h b/src/os/inc/osapi-sockets.h index 09e3b5a75..9b2cac484 100644 --- a/src/os/inc/osapi-sockets.h +++ b/src/os/inc/osapi-sockets.h @@ -76,6 +76,16 @@ typedef enum OS_SocketType_MAX /**< @brief Maximum */ } OS_SocketType_t; +/* NOTE: The shutdown mode enums are also a bitmask, so the specific values are important here */ +/** @brief Shutdown Mode */ +typedef enum +{ + OS_SocketShutdownMode_NONE = 0, /**< @brief Reserved value, no effect */ + OS_SocketShutdownMode_SHUT_READ = 1, /**< @brief Disable future reading */ + OS_SocketShutdownMode_SHUT_WRITE = 2, /**< @brief Disable future writing */ + OS_SocketShutdownMode_SHUT_READWRITE = 3 /**< @brief Disable future reading or writing */ +} OS_SocketShutdownMode_t; + /** * @brief Storage buffer for generic network address * @@ -279,6 +289,20 @@ int32 OS_SocketBind(osal_id_t sock_id, const OS_SockAddr_t *Addr); */ int32 OS_SocketConnect(osal_id_t sock_id, const OS_SockAddr_t *Addr, int32 timeout); +/*-------------------------------------------------------------------------------------*/ +/** + * @brief Implement graceful shutdown of a stream socket + * + * This can be utilized to indicate the end of data stream without immediately closing + * the socket, giving the remote side an indication that the data transfer is complete. + * + * @param[in] sock_id The socket ID + * @param[in] Mode Whether to shutdown reading, writing, or both. + * + * @return Execution status, see @ref OSReturnCodes + */ +int32 OS_SocketShutdown(osal_id_t sock_id, OS_SocketShutdownMode_t Mode); + /*-------------------------------------------------------------------------------------*/ /** * @brief Waits for and accept the next incoming connection on the given socket diff --git a/src/os/portable/os-impl-bsd-sockets.c b/src/os/portable/os-impl-bsd-sockets.c index 7ddb5905c..c992ea78b 100644 --- a/src/os/portable/os-impl-bsd-sockets.c +++ b/src/os/portable/os-impl-bsd-sockets.c @@ -338,6 +338,49 @@ int32 OS_SocketConnect_Impl(const OS_object_token_t *token, const OS_SockAddr_t return return_code; } /* end OS_SocketConnect_Impl */ +/*---------------------------------------------------------------- + Function: OS_SocketShutdown_Impl + + Purpose: Connects the socket to a remote address. + Socket must be of the STREAM variety. + + Returns: OS_SUCCESS on success, or relevant error code + ------------------------------------------------------------------*/ +int32 OS_SocketShutdown_Impl(const OS_object_token_t *token, OS_SocketShutdownMode_t Mode) +{ + OS_impl_file_internal_record_t *conn_impl; + int32 return_code; + int how; + + conn_impl = OS_OBJECT_TABLE_GET(OS_impl_filehandle_table, *token); + + /* Note that when called via the shared layer, + * the "Mode" arg has already been checked/validated. */ + if (Mode == OS_SocketShutdownMode_SHUT_READ) + { + how = SHUT_RD; + } + else if (Mode == OS_SocketShutdownMode_SHUT_WRITE) + { + how = SHUT_WR; + } + else + { + how = SHUT_RDWR; + } + + if (shutdown(conn_impl->fd, how) == 0) + { + return_code = OS_SUCCESS; + } + else + { + return_code = OS_ERROR; + } + + return return_code; +} /* end OS_SocketShutdown_Impl */ + /*---------------------------------------------------------------- * * Function: OS_SocketAccept_Impl diff --git a/src/os/portable/os-impl-no-sockets.c b/src/os/portable/os-impl-no-sockets.c index 8a4befeda..49143764e 100644 --- a/src/os/portable/os-impl-no-sockets.c +++ b/src/os/portable/os-impl-no-sockets.c @@ -71,6 +71,16 @@ int32 OS_SocketConnect_Impl(const OS_object_token_t *token, const OS_SockAddr_t return OS_ERR_NOT_IMPLEMENTED; } +/*---------------------------------------------------------------- + * Implementation for no network configuration + * + * See prototype for argument/return detail + *-----------------------------------------------------------------*/ +int32 OS_SocketShutdown_Impl(const OS_object_token_t *token, OS_SocketShutdownMode_t Mode) +{ + return OS_ERR_NOT_IMPLEMENTED; +} + /*---------------------------------------------------------------- * Implementation for no network configuration * diff --git a/src/os/posix/inc/os-impl-sockets.h b/src/os/posix/inc/os-impl-sockets.h index adae73696..29812594b 100644 --- a/src/os/posix/inc/os-impl-sockets.h +++ b/src/os/posix/inc/os-impl-sockets.h @@ -34,6 +34,7 @@ #include #include #include +#include #define OS_NETWORK_SUPPORTS_IPV6 diff --git a/src/os/shared/inc/os-shared-sockets.h b/src/os/shared/inc/os-shared-sockets.h index 7b9d96bb0..032339fa5 100644 --- a/src/os/shared/inc/os-shared-sockets.h +++ b/src/os/shared/inc/os-shared-sockets.h @@ -84,6 +84,15 @@ int32 OS_SocketAccept_Impl(const OS_object_token_t *sock_token, const OS_object_ ------------------------------------------------------------------*/ int32 OS_SocketConnect_Impl(const OS_object_token_t *token, const OS_SockAddr_t *Addr, int32 timeout); +/*---------------------------------------------------------------- + Function: OS_SocketShutdown_Impl + + Purpose: Graceful shutdown of a stream socket + + Returns: OS_SUCCESS on success, or relevant error code + ------------------------------------------------------------------*/ +int32 OS_SocketShutdown_Impl(const OS_object_token_t *token, OS_SocketShutdownMode_t Mode); + /*---------------------------------------------------------------- Function: OS_SocketRecvFrom_Impl diff --git a/src/os/shared/src/osapi-sockets.c b/src/os/shared/src/osapi-sockets.c index a62d4e14c..5619f120f 100644 --- a/src/os/shared/src/osapi-sockets.c +++ b/src/os/shared/src/osapi-sockets.c @@ -343,6 +343,62 @@ int32 OS_SocketConnect(osal_id_t sock_id, const OS_SockAddr_t *Addr, int32 Timeo return return_code; } /* end OS_SocketConnect */ +/*---------------------------------------------------------------- + * + * Function: OS_SocketShutdown + * + * Purpose: Implemented per public OSAL API + * See description in API and header file for detail + * + *-----------------------------------------------------------------*/ +int32 OS_SocketShutdown(osal_id_t sock_id, OS_SocketShutdownMode_t Mode) +{ + OS_stream_internal_record_t *stream; + OS_object_token_t token; + int32 return_code; + + /* Confirm that "Mode" is one of the 3 acceptable values */ + BUGCHECK(Mode == OS_SocketShutdownMode_SHUT_READ || Mode == OS_SocketShutdownMode_SHUT_WRITE || + Mode == OS_SocketShutdownMode_SHUT_READWRITE, + OS_ERR_INVALID_ARGUMENT); + + return_code = OS_ObjectIdGetById(OS_LOCK_MODE_GLOBAL, LOCAL_OBJID_TYPE, sock_id, &token); + if (return_code == OS_SUCCESS) + { + stream = OS_OBJECT_TABLE_GET(OS_stream_table, token); + + if (stream->socket_domain == OS_SocketDomain_INVALID) + { + return_code = OS_ERR_INCORRECT_OBJ_TYPE; + } + else if (stream->socket_type == OS_SocketType_STREAM && (stream->stream_state & OS_STREAM_STATE_CONNECTED) == 0) + { + /* Stream socket must not be connected */ + return_code = OS_ERR_INCORRECT_OBJ_STATE; + } + else + { + return_code = OS_SocketShutdown_Impl(&token, Mode); + + if (return_code == OS_SUCCESS) + { + if (Mode & OS_SocketShutdownMode_SHUT_READ) + { + stream->stream_state &= ~OS_STREAM_STATE_READABLE; + } + if (Mode & OS_SocketShutdownMode_SHUT_WRITE) + { + stream->stream_state &= ~OS_STREAM_STATE_WRITABLE; + } + } + } + + OS_ObjectIdRelease(&token); + } + + return return_code; +} /* end OS_SocketShutdown */ + /*---------------------------------------------------------------- * * Function: OS_SocketRecvFrom diff --git a/src/tests/network-api-test/network-api-test.c b/src/tests/network-api-test/network-api-test.c index e5a02bd37..70a4bf99a 100644 --- a/src/tests/network-api-test/network-api-test.c +++ b/src/tests/network-api-test/network-api-test.c @@ -35,10 +35,19 @@ #define UT_EXIT_LOOP_MAX 100 /* - * Number of client->server connections to create. - * This tests that the server socket can accept multiple connections. + * Variations of client->server connections to create. + * This tests that the server socket can accept multiple connections, + * and the various combinations of socket shutdown/closure work as expected. */ -#define UT_STREAM_CONNECTION_COUNT 4 +enum +{ + UT_STREAM_CONNECTION_INITIAL, /* On first pass, just check basic read/writes */ + UT_STREAM_CONNECTION_REUSE_SERVER, /* Second pass is the same, confirms server socket can be re-used */ + UT_STREAM_CONNECTION_READ_SHUTDOWN, /* Third pass confirms that read shutdown works correctly */ + UT_STREAM_CONNECTION_WRITE_SHUTDOWN, /* Fourth pass confirms that write shutdown works correctly */ + UT_STREAM_CONNECTION_RDWR_SHUTDOWN, /* Fifth pass confirms that read/write shutdown works correctly */ + UT_STREAM_CONNECTION_MAX +}; osal_id_t s_task_id; osal_id_t p1_socket_id; @@ -417,59 +426,75 @@ void Server_Fn(void) char Buf_trans[8] = {0}; uint8 Buf_each_char_s[256] = {0}; int32 Status; + int32 ExpectedStatus; /* Fill the memory with a count pattern */ UtMemFill(Buf_each_char_s, sizeof(Buf_each_char_s)); - iter = 0; - while (iter < UT_STREAM_CONNECTION_COUNT) + for (iter = UT_STREAM_CONNECTION_INITIAL; iter < UT_STREAM_CONNECTION_MAX; ++iter) { - ++iter; - /* Accept incoming connections */ Status = OS_SocketAccept(s_socket_id, &connsock_id, &addr, OS_PEND); if (Status != OS_SUCCESS) { - snprintf(ServerFn_ErrorString, sizeof(ServerFn_ErrorString), "OS_SocketAccept() return code=%d", - (int)Status); + snprintf(ServerFn_ErrorString, sizeof(ServerFn_ErrorString), "OS_SocketAccept() iter=%u, return code=%d", + (unsigned int)iter, (int)Status); break; } - /* Recieve incoming data from client (should be exactly 4 bytes) */ - Status = OS_TimedRead(connsock_id, Buf_trans, sizeof(Buf_trans), 10); - if (Status != 4) + /* Recieve incoming data from client - + * should be exactly 4 bytes on most cycles, but 0 bytes on the cycle + * where write shutdown was done by client side prior to initial write. */ + if (iter == UT_STREAM_CONNECTION_RDWR_SHUTDOWN) { - snprintf(ServerFn_ErrorString, sizeof(ServerFn_ErrorString), "OS_TimedRead() return code=%d", (int)Status); - break; + ExpectedStatus = 0; } - - /* Send back to client: - * 1. uint32 value indicating number of connections so far (4 bytes) - * 2. Original value recieved above (4 bytes) - * 3. String of all possible 8-bit chars [0-255] (256 bytes) - */ - Status = OS_TimedWrite(connsock_id, &iter, sizeof(iter), 10); - if (Status != sizeof(iter)) + else { - snprintf(ServerFn_ErrorString, sizeof(ServerFn_ErrorString), "OS_TimedWrite(uint32) return code=%d", - (int)Status); - break; + ExpectedStatus = 4; } - - Status = OS_TimedWrite(connsock_id, Buf_trans, 4, 10); - if (Status != 4) + Status = OS_TimedRead(connsock_id, Buf_trans, sizeof(Buf_trans), 10); + if (Status != ExpectedStatus) { - snprintf(ServerFn_ErrorString, sizeof(ServerFn_ErrorString), "OS_TimedWrite(Buf_trans) return code=%d", - (int)Status); + snprintf(ServerFn_ErrorString, sizeof(ServerFn_ErrorString), "OS_TimedRead() iter=%u, return code=%d/%d", + (unsigned int)iter, (int)Status, (int)ExpectedStatus); break; } - Status = OS_TimedWrite(connsock_id, Buf_each_char_s, sizeof(Buf_each_char_s), 10); - if (Status != sizeof(Buf_each_char_s)) + /* + * on iterations where the client is doing a read/readwrite shutdown, it will close the socket, + * and the write calls may return -1 depending on what happens first. So skip the writes. + */ + if (iter != UT_STREAM_CONNECTION_READ_SHUTDOWN && iter != UT_STREAM_CONNECTION_RDWR_SHUTDOWN) { - snprintf(ServerFn_ErrorString, sizeof(ServerFn_ErrorString), - "OS_TimedWrite(Buf_each_char_s) return code=%d", (int)Status); - break; + /* Send back to client: + * 1. uint32 value indicating number of connections so far (4 bytes) + * 2. Original value recieved above (4 bytes) + * 3. String of all possible 8-bit chars [0-255] (256 bytes) + */ + Status = OS_TimedWrite(connsock_id, &iter, sizeof(iter), 10); + if (Status != sizeof(iter)) + { + snprintf(ServerFn_ErrorString, sizeof(ServerFn_ErrorString), + "OS_TimedWrite(uint32) iter=%u, return code=%d", (unsigned int)iter, (int)Status); + break; + } + + Status = OS_TimedWrite(connsock_id, Buf_trans, 4, 10); + if (Status != 4) + { + snprintf(ServerFn_ErrorString, sizeof(ServerFn_ErrorString), + "OS_TimedWrite(Buf_trans) iter=%u, return code=%d", (unsigned int)iter, (int)Status); + break; + } + + Status = OS_TimedWrite(connsock_id, Buf_each_char_s, sizeof(Buf_each_char_s), 10); + if (Status != sizeof(Buf_each_char_s)) + { + snprintf(ServerFn_ErrorString, sizeof(ServerFn_ErrorString), + "OS_TimedWrite(Buf_each_char_s) return code=%d", (int)Status); + break; + } } OS_close(connsock_id); @@ -566,8 +591,7 @@ void TestStreamNetworkApi(void) * Connect to a server - this is done in a loop * to confirm a server socket can be re-used for multiple clients */ - iter = 0; - while (iter < UT_STREAM_CONNECTION_COUNT) + for (iter = UT_STREAM_CONNECTION_INITIAL; iter < UT_STREAM_CONNECTION_MAX; ++iter) { /* Open a client socket */ expected = OS_SUCCESS; @@ -585,7 +609,7 @@ void TestStreamNetworkApi(void) * This is done after valid connection when the c_socket_id is valid, * but it only needs to be done once, so only do this on the first pass. */ - if (iter == 0) + if (iter == UT_STREAM_CONNECTION_INITIAL) { /* OS_TimedRead */ expected = OS_ERR_INVALID_ID; @@ -644,37 +668,91 @@ void TestStreamNetworkApi(void) * Once connection is made between * server and client, transfer data */ - ++iter; - snprintf(Buf_send_c, sizeof(Buf_send_c), "%03x", (iter + 0xabc) & 0xfff); - - /* Send data to server */ - expected = sizeof(Buf_send_c); - actual = OS_TimedWrite(c_socket_id, Buf_send_c, sizeof(Buf_send_c), 10); - UtAssert_True(actual == expected, "OS_TimedWrite() (%ld) == %ld", (long)actual, (long)expected); - - /* Recieve back data from server, first is loop count */ - expected = sizeof(loopcnt); - actual = OS_TimedRead(c_socket_id, &loopcnt, sizeof(loopcnt), 10); - UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected); - UtAssert_UINT32_EQ(iter, loopcnt); - - /* Recieve back data from server, next is original string */ - expected = sizeof(Buf_rcv_c); - actual = OS_TimedRead(c_socket_id, Buf_rcv_c, sizeof(Buf_rcv_c), 10); - UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected); - UtAssert_True(strcmp(Buf_send_c, Buf_rcv_c) == 0, "Buf_rcv_c (%s) == Buf_send_c (%s)", Buf_rcv_c, - Buf_send_c); - - /* Recieve back data from server, next is 8-bit charset */ - expected = sizeof(Buf_each_char_rcv); - actual = OS_TimedRead(c_socket_id, Buf_each_char_rcv, sizeof(Buf_each_char_rcv), 10); - UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected); - UtAssert_MemCmpCount(Buf_each_char_rcv, sizeof(Buf_each_char_rcv), "Verify byte count pattern"); - - /* Server should close the socket, reads will return 0 indicating EOF */ - expected = 0; - actual = OS_TimedRead(c_socket_id, Buf_rcv_c, sizeof(Buf_rcv_c), 10); - UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected); + snprintf(Buf_send_c, sizeof(Buf_send_c), "%03x", (unsigned int)((iter + 0xabc) & 0xfff)); + + /* + * On designated iterations, use "shutdown" to indicate this is the end of the read data + */ + if (iter == UT_STREAM_CONNECTION_READ_SHUTDOWN) + { + expected = OS_SUCCESS; + actual = OS_SocketShutdown(c_socket_id, OS_SocketShutdownMode_SHUT_READ); + UtAssert_True(actual == expected, "OS_SocketShutdown(SHUT_READ) (%ld) == %ld", (long)actual, + (long)expected); + } + + if (iter == UT_STREAM_CONNECTION_RDWR_SHUTDOWN) + { + expected = OS_SUCCESS; + actual = OS_SocketShutdown(c_socket_id, OS_SocketShutdownMode_SHUT_READWRITE); + UtAssert_True(actual == expected, "OS_SocketShutdown(SHUT_READWRITE) (%ld) == %ld", (long)actual, + (long)expected); + } + + if (iter == UT_STREAM_CONNECTION_READ_SHUTDOWN || iter == UT_STREAM_CONNECTION_RDWR_SHUTDOWN) + { + /* Attempt to read data, would block/timeout normally, but + * due to read shutdown it should immediately return instead. */ + expected = 0; + actual = OS_TimedRead(c_socket_id, Buf_rcv_c, sizeof(Buf_rcv_c), 10); + UtAssert_True(actual == expected, "OS_TimedRead() after read shutdown (%ld) == %ld", (long)actual, + (long)expected); + } + + if (iter != UT_STREAM_CONNECTION_RDWR_SHUTDOWN) + { + /* Send data to server - this should still work after read shutdown, but not after write shutdown */ + expected = sizeof(Buf_send_c); + actual = OS_TimedWrite(c_socket_id, Buf_send_c, sizeof(Buf_send_c), 10); + UtAssert_True(actual == expected, "OS_TimedWrite() (%ld) == %ld", (long)actual, (long)expected); + } + + /* On the designated iteration, use shutdown to indicate this is the end of the written data */ + if (iter == UT_STREAM_CONNECTION_WRITE_SHUTDOWN) + { + expected = OS_SUCCESS; + actual = OS_SocketShutdown(c_socket_id, OS_SocketShutdownMode_SHUT_WRITE); + UtAssert_True(actual == expected, "OS_SocketShutdown(SHUT_WRITE) (%ld) == %ld", (long)actual, + (long)expected); + } + + if (iter == UT_STREAM_CONNECTION_WRITE_SHUTDOWN || iter == UT_STREAM_CONNECTION_RDWR_SHUTDOWN) + { + /* If write shutdown worked as expected, write should return an error */ + expected = OS_ERROR; + actual = OS_TimedWrite(c_socket_id, Buf_send_c, sizeof(Buf_send_c), 10); + UtAssert_True(actual == expected, "OS_TimedWrite() after SHUT_WRITE (%ld) == %ld", (long)actual, + (long)expected); + } + + /* On iterations where read was shutdown, skip the rest (reads after shutdown are unclear, may or may not + * work) */ + if (iter != UT_STREAM_CONNECTION_READ_SHUTDOWN && iter != UT_STREAM_CONNECTION_RDWR_SHUTDOWN) + { + /* Recieve back data from server, first is loop count */ + expected = sizeof(loopcnt); + actual = OS_TimedRead(c_socket_id, &loopcnt, sizeof(loopcnt), 10); + UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected); + UtAssert_UINT32_EQ(iter, loopcnt); + + /* Recieve back data from server, next is original string */ + expected = sizeof(Buf_rcv_c); + actual = OS_TimedRead(c_socket_id, Buf_rcv_c, sizeof(Buf_rcv_c), 10); + UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected); + UtAssert_True(strcmp(Buf_send_c, Buf_rcv_c) == 0, "Buf_rcv_c (%s) == Buf_send_c (%s)", Buf_rcv_c, + Buf_send_c); + + /* Recieve back data from server, next is 8-bit charset */ + expected = sizeof(Buf_each_char_rcv); + actual = OS_TimedRead(c_socket_id, Buf_each_char_rcv, sizeof(Buf_each_char_rcv), 10); + UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected); + UtAssert_MemCmpCount(Buf_each_char_rcv, sizeof(Buf_each_char_rcv), "Verify byte count pattern"); + + /* Server should close the socket, reads will return 0 indicating EOF */ + expected = 0; + actual = OS_TimedRead(c_socket_id, Buf_rcv_c, sizeof(Buf_rcv_c), 10); + UtAssert_True(actual == expected, "OS_TimedRead() (%ld) == %ld", (long)actual, (long)expected); + } OS_close(c_socket_id); } diff --git a/src/unit-test-coverage/portable/src/coveragetest-bsd-sockets.c b/src/unit-test-coverage/portable/src/coveragetest-bsd-sockets.c index be2912a34..3cc10d031 100644 --- a/src/unit-test-coverage/portable/src/coveragetest-bsd-sockets.c +++ b/src/unit-test-coverage/portable/src/coveragetest-bsd-sockets.c @@ -213,6 +213,23 @@ void Test_OS_SocketConnect_Impl(void) OSAPI_TEST_FUNCTION_RC(OS_SocketConnect_Impl, (&token, &addr, 0), OS_SUCCESS); } +void Test_OS_SocketShutdown_Impl(void) +{ + OS_object_token_t token = {0}; + + /* Set up token for index 0 */ + token.obj_idx = UT_INDEX_0; + + /* Check all 3 valid modes */ + OSAPI_TEST_FUNCTION_RC(OS_SocketShutdown_Impl, (&token, OS_SocketShutdownMode_SHUT_READ), OS_SUCCESS); + OSAPI_TEST_FUNCTION_RC(OS_SocketShutdown_Impl, (&token, OS_SocketShutdownMode_SHUT_WRITE), OS_SUCCESS); + OSAPI_TEST_FUNCTION_RC(OS_SocketShutdown_Impl, (&token, OS_SocketShutdownMode_SHUT_READWRITE), OS_SUCCESS); + + /* Check OS call failure */ + UT_SetDeferredRetcode(UT_KEY(OCS_shutdown), 1, -1); + OSAPI_TEST_FUNCTION_RC(OS_SocketShutdown_Impl, (&token, OS_SocketShutdownMode_SHUT_READ), OS_ERROR); +} + void Test_OS_SocketAccept_Impl(void) { OS_object_token_t sock_token = {0}; @@ -472,6 +489,7 @@ void UtTest_Setup(void) ADD_TEST(OS_SocketOpen_Impl); ADD_TEST(OS_SocketBind_Impl); ADD_TEST(OS_SocketConnect_Impl); + ADD_TEST(OS_SocketShutdown_Impl); ADD_TEST(OS_SocketAccept_Impl); ADD_TEST(OS_SocketRecvFrom_Impl); ADD_TEST(OS_SocketSendTo_Impl); diff --git a/src/unit-test-coverage/portable/src/coveragetest-no-sockets.c b/src/unit-test-coverage/portable/src/coveragetest-no-sockets.c index 34977314c..62428afe3 100644 --- a/src/unit-test-coverage/portable/src/coveragetest-no-sockets.c +++ b/src/unit-test-coverage/portable/src/coveragetest-no-sockets.c @@ -32,6 +32,7 @@ void Test_No_Sockets(void) OSAPI_TEST_FUNCTION_RC(OS_SocketBind_Impl, (NULL, NULL), OS_ERR_NOT_IMPLEMENTED); OSAPI_TEST_FUNCTION_RC(OS_SocketConnect_Impl, (NULL, NULL, 0), OS_ERR_NOT_IMPLEMENTED); OSAPI_TEST_FUNCTION_RC(OS_SocketAccept_Impl, (NULL, NULL, NULL, 0), OS_ERR_NOT_IMPLEMENTED); + OSAPI_TEST_FUNCTION_RC(OS_SocketShutdown_Impl, (NULL, 0), OS_ERR_NOT_IMPLEMENTED); OSAPI_TEST_FUNCTION_RC(OS_SocketRecvFrom_Impl, (NULL, NULL, 0, NULL, 0), OS_ERR_NOT_IMPLEMENTED); OSAPI_TEST_FUNCTION_RC(OS_SocketSendTo_Impl, (NULL, NULL, 0, NULL), OS_ERR_NOT_IMPLEMENTED); OSAPI_TEST_FUNCTION_RC(OS_SocketGetInfo_Impl, (NULL, NULL), OS_SUCCESS); diff --git a/src/unit-test-coverage/shared/src/coveragetest-sockets.c b/src/unit-test-coverage/shared/src/coveragetest-sockets.c index ea73c12b9..2240d03b9 100644 --- a/src/unit-test-coverage/shared/src/coveragetest-sockets.c +++ b/src/unit-test-coverage/shared/src/coveragetest-sockets.c @@ -257,6 +257,90 @@ void Test_OS_SocketConnect(void) (long)actual); } +/***************************************************************************** + * + * Test case for OS_SocketShutdown() + * + *****************************************************************************/ +void Test_OS_SocketShutdown(void) +{ + /* + * Test Case For: + * int32 OS_SocketShutdown(osal_id_t sock_id, OS_SocketShutdownMode_t Mode) + */ + int32 expected = OS_SUCCESS; + int32 actual = ~OS_SUCCESS; + osal_index_t idbuf; + + idbuf = UT_INDEX_1; + OS_UT_SetupTestTargetIndex(OS_OBJECT_TYPE_OS_STREAM, idbuf); + OS_stream_table[idbuf].socket_domain = OS_SocketDomain_INET; + OS_stream_table[idbuf].socket_type = OS_SocketType_STREAM; + OS_stream_table[idbuf].stream_state = + OS_STREAM_STATE_CONNECTED | OS_STREAM_STATE_READABLE | OS_STREAM_STATE_WRITABLE; + + /* nominal */ + actual = OS_SocketShutdown(UT_OBJID_1, OS_SocketShutdownMode_SHUT_READ); + UtAssert_True(actual == expected, "OS_SocketShutdown() (%ld) == OS_SUCCESS", (long)actual); + UtAssert_True((OS_stream_table[idbuf].stream_state & OS_STREAM_STATE_READABLE) == 0, "Stream bits cleared"); + UtAssert_True((OS_stream_table[idbuf].stream_state & OS_STREAM_STATE_WRITABLE) != 0, "Stream bits unchanged"); + + OS_stream_table[idbuf].stream_state = + OS_STREAM_STATE_CONNECTED | OS_STREAM_STATE_READABLE | OS_STREAM_STATE_WRITABLE; + actual = OS_SocketShutdown(UT_OBJID_1, OS_SocketShutdownMode_SHUT_WRITE); + UtAssert_True(actual == expected, "OS_SocketShutdown() (%ld) == OS_SUCCESS", (long)actual); + UtAssert_True((OS_stream_table[idbuf].stream_state & OS_STREAM_STATE_READABLE) != 0, "Stream bits unchanged"); + UtAssert_True((OS_stream_table[idbuf].stream_state & OS_STREAM_STATE_WRITABLE) == 0, "Stream bits cleared"); + + OS_stream_table[idbuf].stream_state = + OS_STREAM_STATE_CONNECTED | OS_STREAM_STATE_READABLE | OS_STREAM_STATE_WRITABLE; + actual = OS_SocketShutdown(UT_OBJID_1, OS_SocketShutdownMode_SHUT_READWRITE); + UtAssert_True(actual == expected, "OS_SocketShutdown() (%ld) == OS_SUCCESS", (long)actual); + UtAssert_True((OS_stream_table[idbuf].stream_state & OS_STREAM_STATE_READABLE) == 0, "Stream bits cleared"); + UtAssert_True((OS_stream_table[idbuf].stream_state & OS_STREAM_STATE_WRITABLE) == 0, "Stream bits unchanged"); + + /* Invalid Argument */ + expected = OS_ERR_INVALID_ARGUMENT; + OS_stream_table[idbuf].stream_state = + OS_STREAM_STATE_CONNECTED | OS_STREAM_STATE_READABLE | OS_STREAM_STATE_WRITABLE; + actual = OS_SocketShutdown(UT_OBJID_1, OS_SocketShutdownMode_NONE); + UtAssert_True(actual == expected, "OS_SocketShutdown() (%ld) == OS_ERR_INVALID_ARGUMENT", (long)actual); + UtAssert_True((OS_stream_table[idbuf].stream_state & OS_STREAM_STATE_READABLE) != 0, "Stream bits unchanged"); + UtAssert_True((OS_stream_table[idbuf].stream_state & OS_STREAM_STATE_WRITABLE) != 0, "Stream bits unchanged"); + + /* Implementation failure */ + expected = -1234; + UT_SetDefaultReturnValue(UT_KEY(OS_SocketShutdown_Impl), expected); + actual = OS_SocketShutdown(UT_OBJID_1, OS_SocketShutdownMode_SHUT_READWRITE); + UtAssert_True(actual == expected, "OS_SocketShutdown() impl failure (%ld) == %ld", (long)actual, (long)expected); + UtAssert_True((OS_stream_table[idbuf].stream_state & OS_STREAM_STATE_READABLE) != 0, "Stream bits unchanged"); + UtAssert_True((OS_stream_table[idbuf].stream_state & OS_STREAM_STATE_WRITABLE) != 0, "Stream bits unchanged"); + UT_ResetState(UT_KEY(OS_SocketShutdown_Impl)); + + /* Invalid ID */ + expected = OS_ERR_INVALID_ID; + UT_SetDefaultReturnValue(UT_KEY(OS_ObjectIdGetById), expected); + actual = OS_SocketShutdown(UT_OBJID_1, OS_SocketShutdownMode_SHUT_READWRITE); + UtAssert_True(actual == expected, "OS_SocketShutdown() invalid ID (%ld) == OS_ERR_INVALID_ID", (long)actual); + UtAssert_True((OS_stream_table[idbuf].stream_state & OS_STREAM_STATE_READABLE) != 0, "Stream bits unchanged"); + UtAssert_True((OS_stream_table[idbuf].stream_state & OS_STREAM_STATE_WRITABLE) != 0, "Stream bits unchanged"); + UT_ResetState(UT_KEY(OS_ObjectIdGetById)); + + /* Unconnected socket */ + expected = OS_ERR_INCORRECT_OBJ_STATE; + OS_stream_table[idbuf].stream_state = 0; + actual = OS_SocketShutdown(UT_OBJID_1, OS_SocketShutdownMode_SHUT_READWRITE); + UtAssert_True(actual == expected, "OS_SocketShutdown() unconnected (%ld) == OS_ERR_INCORRECT_OBJ_STATE", + (long)actual); + + /* Invalid socket type */ + expected = OS_ERR_INCORRECT_OBJ_TYPE; + OS_stream_table[idbuf].socket_domain = OS_SocketDomain_INVALID; + actual = OS_SocketShutdown(UT_OBJID_1, OS_SocketShutdownMode_SHUT_READWRITE); + UtAssert_True(actual == expected, "OS_SocketShutdown() unconnected (%ld) == OS_ERR_INCORRECT_OBJ_TYPE", + (long)actual); +} + /***************************************************************************** * * Test case for OS_SocketRecvFrom() @@ -496,6 +580,7 @@ void UtTest_Setup(void) ADD_TEST(OS_SocketConnect); ADD_TEST(OS_SocketRecvFrom); ADD_TEST(OS_SocketSendTo); + ADD_TEST(OS_SocketShutdown); ADD_TEST(OS_SocketGetIdByName); ADD_TEST(OS_SocketGetInfo); ADD_TEST(OS_CreateSocketName); diff --git a/src/unit-test-coverage/ut-stubs/inc/OCS_sys_socket.h b/src/unit-test-coverage/ut-stubs/inc/OCS_sys_socket.h index 49cb05f9d..a8f311683 100644 --- a/src/unit-test-coverage/ut-stubs/inc/OCS_sys_socket.h +++ b/src/unit-test-coverage/ut-stubs/inc/OCS_sys_socket.h @@ -73,7 +73,10 @@ enum OCS_SOL_SOCKET, OCS_SO_REUSEADDR, OCS_SO_ERROR, - OCS_MSG_DONTWAIT + OCS_MSG_DONTWAIT, + OCS_SHUT_WR, + OCS_SHUT_RD, + OCS_SHUT_RDWR }; /* ----------------------------------------- */ @@ -90,6 +93,7 @@ extern OCS_ssize_t OCS_recvfrom(int fd, void *buf, size_t n, int flags, struct O extern OCS_ssize_t OCS_sendto(int fd, const void *buf, size_t n, int flags, const struct OCS_sockaddr *addr, OCS_socklen_t addr_len); extern int OCS_setsockopt(int fd, int level, int optname, const void *optval, OCS_socklen_t optlen); +extern int OCS_shutdown(int fd, int how); extern int OCS_socket(int domain, int type, int protocol); #endif /* OCS_SYS_SOCKET_H */ diff --git a/src/unit-test-coverage/ut-stubs/override_inc/sys/socket.h b/src/unit-test-coverage/ut-stubs/override_inc/sys/socket.h index 5d229e209..a00f48883 100644 --- a/src/unit-test-coverage/ut-stubs/override_inc/sys/socket.h +++ b/src/unit-test-coverage/ut-stubs/override_inc/sys/socket.h @@ -46,6 +46,7 @@ #define recvfrom OCS_recvfrom #define sendto OCS_sendto #define setsockopt OCS_setsockopt +#define shutdown OCS_shutdown #define socket OCS_socket #define EINPROGRESS OCS_EINPROGRESS @@ -60,5 +61,8 @@ #define SO_REUSEADDR OCS_SO_REUSEADDR #define SO_ERROR OCS_SO_ERROR #define MSG_DONTWAIT OCS_MSG_DONTWAIT +#define SHUT_WR OCS_SHUT_WR +#define SHUT_RD OCS_SHUT_RD +#define SHUT_RDWR OCS_SHUT_RDWR #endif /* OVERRIDE_SYS_SOCKET_H */ diff --git a/src/unit-test-coverage/ut-stubs/src/os-shared-sockets-impl-stubs.c b/src/unit-test-coverage/ut-stubs/src/os-shared-sockets-impl-stubs.c index 22d8af09c..227fcde60 100644 --- a/src/unit-test-coverage/ut-stubs/src/os-shared-sockets-impl-stubs.c +++ b/src/unit-test-coverage/ut-stubs/src/os-shared-sockets-impl-stubs.c @@ -241,3 +241,20 @@ int32 OS_SocketSendTo_Impl(const OS_object_token_t *token, const void *buffer, s return UT_GenStub_GetReturnValue(OS_SocketSendTo_Impl, int32); } + +/* + * ---------------------------------------------------- + * Generated stub function for OS_SocketShutdown_Impl() + * ---------------------------------------------------- + */ +int32 OS_SocketShutdown_Impl(const OS_object_token_t *token, OS_SocketShutdownMode_t Mode) +{ + UT_GenStub_SetupReturnBuffer(OS_SocketShutdown_Impl, int32); + + UT_GenStub_AddParam(OS_SocketShutdown_Impl, const OS_object_token_t *, token); + UT_GenStub_AddParam(OS_SocketShutdown_Impl, OS_SocketShutdownMode_t, Mode); + + UT_GenStub_Execute(OS_SocketShutdown_Impl, Basic, NULL); + + return UT_GenStub_GetReturnValue(OS_SocketShutdown_Impl, int32); +} diff --git a/src/unit-test-coverage/ut-stubs/src/sys-socket-stubs.c b/src/unit-test-coverage/ut-stubs/src/sys-socket-stubs.c index 4d71e50e8..a3c6c112b 100644 --- a/src/unit-test-coverage/ut-stubs/src/sys-socket-stubs.c +++ b/src/unit-test-coverage/ut-stubs/src/sys-socket-stubs.c @@ -107,6 +107,14 @@ int OCS_setsockopt(int fd, int level, int optname, const void *optval, OCS_sockl return UT_DEFAULT_IMPL(OCS_setsockopt); } +int OCS_shutdown(int fd, int how) +{ + UT_Stub_RegisterContextGenericArg(UT_KEY(OCS_shutdown), fd); + UT_Stub_RegisterContextGenericArg(UT_KEY(OCS_shutdown), how); + + return UT_DEFAULT_IMPL(OCS_shutdown); +} + int OCS_socket(int domain, int type, int protocol) { UT_Stub_RegisterContextGenericArg(UT_KEY(OCS_socket), domain); diff --git a/src/ut-stubs/osapi-sockets-stubs.c b/src/ut-stubs/osapi-sockets-stubs.c index c515f8bdb..3449afd47 100644 --- a/src/ut-stubs/osapi-sockets-stubs.c +++ b/src/ut-stubs/osapi-sockets-stubs.c @@ -267,3 +267,20 @@ int32 OS_SocketSendTo(osal_id_t sock_id, const void *buffer, size_t buflen, cons return UT_GenStub_GetReturnValue(OS_SocketSendTo, int32); } + +/* + * ---------------------------------------------------- + * Generated stub function for OS_SocketShutdown() + * ---------------------------------------------------- + */ +int32 OS_SocketShutdown(osal_id_t sock_id, OS_SocketShutdownMode_t Mode) +{ + UT_GenStub_SetupReturnBuffer(OS_SocketShutdown, int32); + + UT_GenStub_AddParam(OS_SocketShutdown, osal_id_t, sock_id); + UT_GenStub_AddParam(OS_SocketShutdown, OS_SocketShutdownMode_t, Mode); + + UT_GenStub_Execute(OS_SocketShutdown, Basic, NULL); + + return UT_GenStub_GetReturnValue(OS_SocketShutdown, int32); +}