From b546539ab8bb81af21b21cc9206603bde5a4cde9 Mon Sep 17 00:00:00 2001 From: Matthew Bloch Date: Sat, 9 Jun 2012 02:25:12 +0100 Subject: [PATCH] Rewrote error & log functions to be more general, use longjmp to get out of trouble and into predictable cleanup functions (one for each of serve, client & control contexts). We use 'fatal' to mean 'kill the thread' and 'error' to mean 'don't kill the thread', assuming some recovery action, except I don't use error anywhere yet. --- src/client.c | 41 ++++++++++++------ src/control.c | 38 +++++++++++----- src/flexnbd.c | 61 +++++++++++++------------- src/ioutil.h | 4 +- src/readwrite.c | 34 +++++++-------- src/remote.c | 6 +-- src/self_pipe.c | 3 +- src/serve.c | 68 +++++++++++++++++------------ src/serve.h | 2 + src/util.c | 65 ++++++++-------------------- src/util.h | 108 ++++++++++++++++++++++++++++++++++++++-------- tests/check_acl.c | 3 +- 12 files changed, 259 insertions(+), 174 deletions(-) diff --git a/src/client.c b/src/client.c index 092689f..bb5c10f 100644 --- a/src/client.c +++ b/src/client.c @@ -97,7 +97,7 @@ void write_not_zeroes(struct client* client, off64_t from, int len) fprintf(stderr, "\n"); } - #define DO_READ(dst, len) CLIENT_ERROR_ON_FAILURE( \ + #define DO_READ(dst, len) FATAL_IF_NEGATIVE( \ readloop( \ client->socket, \ (dst), \ @@ -168,7 +168,7 @@ int client_read_request( struct client * client , struct nbd_request *out_reques FD_ZERO(&fds); FD_SET(client->socket, &fds); self_pipe_fd_set( client->stop_signal, &fds ); - CLIENT_ERROR_ON_FAILURE(select(FD_SETSIZE, &fds, NULL, NULL, NULL), + FATAL_IF_NEGATIVE(select(FD_SETSIZE, &fds, NULL, NULL, NULL), "select() failed"); if ( self_pipe_fd_isset( client->stop_signal, &fds ) ) @@ -180,7 +180,7 @@ int client_read_request( struct client * client , struct nbd_request *out_reques return 0; /* neat point to close the socket */ } else { - CLIENT_ERROR_ON_FAILURE(-1, "Error reading request"); + FATAL_IF_NEGATIVE(-1, "Error reading request"); } } @@ -223,7 +223,7 @@ void client_write_init( struct client * client, uint64_t size ) nbd_h2r_init( &init, &init_raw ); - CLIENT_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( writeloop(client->socket, &init_raw, sizeof(init_raw)), "Couldn't send hello" ); @@ -239,7 +239,7 @@ int client_request_needs_reply( struct client * client, struct nbd_request reque debug("request type %d", request.type); if (request.magic != REQUEST_MAGIC) - CLIENT_ERROR("Bad magic %08x", request.magic); + fatal("Bad magic %08x", request.magic); switch (request.type) { @@ -265,7 +265,7 @@ int client_request_needs_reply( struct client * client, struct nbd_request reque return 0; default: - CLIENT_ERROR("Unknown request %08x", request.type); + fatal("Unknown request %08x", request.type); } return 1; } @@ -279,7 +279,7 @@ void client_reply_to_read( struct client* client, struct nbd_request request ) client_write_reply( client, &request, 0); offset = request.from; - CLIENT_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( sendfileloop( client->socket, client->fileno, @@ -298,7 +298,7 @@ void client_reply_to_write( struct client* client, struct nbd_request request ) write_not_zeroes( client, request.from, request.len ); } else { - CLIENT_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( readloop( client->socket, client->mapped + request.from, @@ -315,7 +315,7 @@ void client_reply_to_write( struct client* client, struct nbd_request request ) uint64_t from_rounded = request.from & (!0xfff); uint64_t len_rounded = request.len + (request.from - from_rounded); - CLIENT_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( msync( client->mapped + from_rounded, len_rounded, @@ -371,12 +371,26 @@ void client_send_hello(struct client* client) client_write_init( client, client->serve->size ); } +void client_cleanup(struct client* client, int fatal) +{ + info("client cleanup"); + + if (client->socket) + close(client->socket); + if (client->mapped) + munmap(client->mapped, client->serve->size); + if (client->fileno) + close(client->fileno); +} + void* client_serve(void* client_uncast) { struct client* client = (struct client*) client_uncast; + error_set_handler((cleanup_handler*) client_cleanup, client); + //client_open_file(client); - CLIENT_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( open_and_mmap( client->serve->filename, &client->fileno, @@ -390,14 +404,13 @@ void* client_serve(void* client_uncast) while (client_serve_request(client) == 0) ; - CLIENT_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( close(client->socket), "Couldn't close socket %d", client->socket ); - - close(client->fileno); - munmap(client->mapped, client->serve->size); + + client_cleanup(client, 0); return NULL; } diff --git a/src/control.c b/src/control.c index 67e7c30..d77e137 100644 --- a/src/control.c +++ b/src/control.c @@ -115,6 +115,13 @@ void* mirror_runner(void* serve_params_uncast) written += run; } current += run; + + if (serve->mirror->signal_abandon) { + if (pass == last_pass) + server_unlock_io( serve ); + close(serve->mirror->client); + goto abandon_mirror; + } } /* if we've not written anything */ @@ -122,6 +129,7 @@ void* mirror_runner(void* serve_params_uncast) pass = last_pass; } + /* a successful finish ends here */ switch (serve->mirror->action_at_finish) { case ACTION_PROXY: @@ -131,18 +139,19 @@ void* mirror_runner(void* serve_params_uncast) break; case ACTION_EXIT: debug("exit!"); + close(serve->mirror->client); serve_signal_close( serve ); /* fall through */ case ACTION_NOTHING: debug("nothing!"); close(serve->mirror->client); } + server_unlock_io( serve ); +abandon_mirror: free(serve->mirror->dirty_map); free(serve->mirror); serve->mirror = NULL; /* and we're gone */ - - server_unlock_io( serve ); return NULL; } @@ -226,7 +235,7 @@ int control_mirror(struct control_params* client, int linesc, char** lines) mirror->max_bytes_per_second = max_bytes_per_second; mirror->action_at_finish = action_at_finish; - CLIENT_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( open_and_mmap( client->serve->filename, &map_fd, @@ -242,7 +251,7 @@ int control_mirror(struct control_params* client, int linesc, char** lines) client->serve->mirror = mirror; - CLIENT_ERROR_ON_FAILURE( /* FIXME should free mirror on error */ + FATAL_IF_NEGATIVE( /* FIXME should free mirror on error */ pthread_create( &mirror->thread, NULL, @@ -287,6 +296,13 @@ int control_status(struct control_params* client, int linesc, char** lines) return 0; } +void control_cleanup(struct control_params* client, int fatal) +{ + if (client->socket) + close(client->socket); + free(client); +} + /** Master command parser for control socket connections, delegates quickly */ void* control_serve(void* client_uncast) { @@ -294,6 +310,8 @@ void* control_serve(void* client_uncast) char **lines = NULL; int finished=0; + error_set_handler((cleanup_handler*) control_cleanup, client); + while (!finished) { int i, linesc; linesc = read_lines_until_blankline(client->socket, 256, &lines); @@ -326,8 +344,8 @@ void* control_serve(void* client_uncast) free(lines); } - close(client->socket); - free(client); + control_cleanup(client, 0); + return NULL; } @@ -340,7 +358,7 @@ void accept_control_connection(struct server* params, int client_fd, union mysoc control_params->socket = client_fd; control_params->serve = params; - SERVER_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( pthread_create( &control_thread, NULL, @@ -359,7 +377,7 @@ void serve_open_control_socket(struct server* params) return; params->control_fd = socket(AF_UNIX, SOCK_STREAM, 0); - SERVER_ERROR_ON_FAILURE(params->control_fd , + FATAL_IF_NEGATIVE(params->control_fd , "Couldn't create control socket"); memset(&bind_address, 0, sizeof(bind_address)); @@ -368,13 +386,13 @@ void serve_open_control_socket(struct server* params) unlink(params->control_socket_name); /* ignore failure */ - SERVER_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( bind(params->control_fd , &bind_address, sizeof(bind_address)), "Couldn't bind control socket to %s", params->control_socket_name ); - SERVER_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( listen(params->control_fd , 5), "Couldn't listen on control socket" ); diff --git a/src/flexnbd.c b/src/flexnbd.c index b334594..cbb4f89 100644 --- a/src/flexnbd.c +++ b/src/flexnbd.c @@ -57,16 +57,16 @@ void params_serve( { out->tcp_backlog = 10; /* does this need to be settable? */ - if (s_ip_address == NULL) - SERVER_ERROR("No IP address supplied"); - if (s_port == NULL) - SERVER_ERROR("No port number supplied"); - if (s_file == NULL) - SERVER_ERROR("No filename supplied"); + FATAL_IF_NULL(s_ip_address, "No IP address supplied"); + FATAL_IF_NULL(s_port, "No port number supplied"); + FATAL_IF_NULL(s_file, "No filename supplied"); - if (parse_ip_to_sockaddr(&out->bind_to.generic, s_ip_address) == 0) - SERVER_ERROR("Couldn't parse server address '%s' (use 0 if " - "you want to bind to all IPs)", s_ip_address); + FATAL_IF_ZERO( + parse_ip_to_sockaddr(&out->bind_to.generic, s_ip_address), + "Couldn't parse server address '%s' (use 0 if " + "you want to bind to all IPs)", + s_ip_address + ); /* control_socket_name is optional. It just won't get created if * we pass NULL. */ @@ -74,11 +74,11 @@ void params_serve( out->acl = acl_create( acl_entries, s_acl_entries, default_deny ); if (out->acl && out->acl->len != acl_entries) - SERVER_ERROR("Bad ACL entry '%s'", s_acl_entries[out->acl->len]); + fatal("Bad ACL entry '%s'", s_acl_entries[out->acl->len]); out->bind_to.v4.sin_port = atoi(s_port); if (out->bind_to.v4.sin_port < 0 || out->bind_to.v4.sin_port > 65535) - SERVER_ERROR("Port number must be >= 0 and <= 65535"); + fatal("Port number must be >= 0 and <= 65535"); out->bind_to.v4.sin_port = htobe16(out->bind_to.v4.sin_port); out->filename = s_file; @@ -111,26 +111,24 @@ void params_readwrite( char* s_length_or_filename ) { - if (s_ip_address == NULL) - SERVER_ERROR("No IP address supplied"); - if (s_port == NULL) - SERVER_ERROR("No port number supplied"); - if (s_from == NULL) - SERVER_ERROR("No from supplied"); - if (s_length_or_filename == NULL) - SERVER_ERROR("No length supplied"); + FATAL_IF_NULL(s_ip_address, "No IP address supplied"); + FATAL_IF_NULL(s_port, "No port number supplied"); + FATAL_IF_NULL(s_from, "No from supplied"); + FATAL_IF_NULL(s_length_or_filename, "No length supplied"); - if (parse_ip_to_sockaddr(&out->connect_to.generic, s_ip_address) == 0) - SERVER_ERROR("Couldn't parse connection address '%s'", - s_ip_address); + FATAL_IF_ZERO( + parse_ip_to_sockaddr(&out->connect_to.generic, s_ip_address), + "Couldn't parse connection address '%s'", + s_ip_address + ); if (s_bind_address != NULL && parse_ip_to_sockaddr(&out->connect_from.generic, s_bind_address) == 0) - SERVER_ERROR("Couldn't parse bind address '%s'", s_bind_address); + fatal("Couldn't parse bind address '%s'", s_bind_address); /* FIXME: duplicated from above */ out->connect_to.v4.sin_port = atoi(s_port); if (out->connect_to.v4.sin_port < 0 || out->connect_to.v4.sin_port > 65535) - SERVER_ERROR("Port number must be >= 0 and <= 65535"); + fatal("Port number must be >= 0 and <= 65535"); out->connect_to.v4.sin_port = htobe16(out->connect_to.v4.sin_port); out->from = atol(s_from); @@ -143,12 +141,12 @@ void params_readwrite( else { out->data_fd = open( s_length_or_filename, O_RDONLY); - SERVER_ERROR_ON_FAILURE(out->data_fd, + FATAL_IF_NEGATIVE(out->data_fd, "Couldn't open %s", s_length_or_filename); out->len = lseek64(out->data_fd, 0, SEEK_END); - SERVER_ERROR_ON_FAILURE(out->len, + FATAL_IF_NEGATIVE(out->len, "Couldn't find length of %s", s_length_or_filename); - SERVER_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( lseek64(out->data_fd, 0, SEEK_SET), "Couldn't rewind %s", s_length_or_filename ); @@ -189,7 +187,7 @@ void read_serve_param( int c, char **ip_addr, char **ip_port, char **file, char *default_deny = 1; break; case 'v': - set_debug(1); + log_level = 0; break; default: exit_err( serve_help_text ); @@ -221,7 +219,7 @@ void read_readwrite_param( int c, char **ip_addr, char **ip_port, char **bind_ad *bind_addr = optarg; break; case 'v': - set_debug(1); + log_level = 0; break; default: exit_err( read_help_text ); @@ -240,7 +238,7 @@ void read_sock_param( int c, char **sock, char *help_text ) *sock = optarg; break; case 'v': - set_debug(1); + log_level = 0; break; default: exit_err( help_text ); @@ -272,7 +270,7 @@ void read_mirror_param( int c, char **sock, char **ip_addr, char **ip_port, char case 'b': *bind_addr = optarg; case 'v': - set_debug(1); + log_level = 0; break; default: exit_err( mirror_help_text ); @@ -538,7 +536,6 @@ int main(int argc, char** argv) { signal(SIGPIPE, SIG_IGN); /* calls to splice() unhelpfully throw this */ error_init(); - set_debug(0); if (argc < 2) { exit_err( help_help_text ); diff --git a/src/ioutil.h b/src/ioutil.h index 30fe66c..3f4b727 100644 --- a/src/ioutil.h +++ b/src/ioutil.h @@ -1,8 +1,8 @@ #ifndef __IOUTIL_H #define __IOUTIL_H - #include "serve.h" +struct bitset_mapping; /* don't need whole of bitset.h here */ /** Returns a bit field representing which blocks are allocated in file * descriptor ''fd''. You must supply the size, and the resolution at which @@ -10,7 +10,7 @@ * allocated blocks at a finer resolution than you've asked for, any block * or part block will count as "allocated" with the corresponding bit set. */ -char* build_allocation_map(int fd, off64_t size, int resolution); +struct bitset_mapping* build_allocation_map(int fd, off64_t size, int resolution); /** Repeat a write() operation that succeeds partially until ''size'' bytes * are written, or an error is returned, when it returns -1 as usual. diff --git a/src/readwrite.c b/src/readwrite.c index aca8648..c427742 100644 --- a/src/readwrite.c +++ b/src/readwrite.c @@ -10,15 +10,15 @@ int socket_connect(struct sockaddr* to, struct sockaddr* from) { int fd = socket(to->sa_family == AF_INET ? PF_INET : PF_INET6, SOCK_STREAM, 0); - SERVER_ERROR_ON_FAILURE(fd, "Couldn't create client socket"); + FATAL_IF_NEGATIVE(fd, "Couldn't create client socket"); if (NULL != from) - SERVER_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( bind(fd, from, sizeof(struct sockaddr_in6)), "bind() failed" ); - SERVER_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( connect(fd, to, sizeof(struct sockaddr_in6)),"connect failed" ); return fd; @@ -27,12 +27,12 @@ int socket_connect(struct sockaddr* to, struct sockaddr* from) off64_t socket_nbd_read_hello(int fd) { struct nbd_init init; - SERVER_ERROR_ON_FAILURE(readloop(fd, &init, sizeof(init)), + FATAL_IF_NEGATIVE(readloop(fd, &init, sizeof(init)), "Couldn't read init"); if (strncmp(init.passwd, INIT_PASSWD, 8) != 0) - SERVER_ERROR("wrong passwd"); + fatal("wrong passwd"); if (be64toh(init.magic) != INIT_MAGIC) - SERVER_ERROR("wrong magic (%x)", be64toh(init.magic)); + fatal("wrong magic (%x)", be64toh(init.magic)); return be64toh(init.size); } @@ -48,14 +48,14 @@ void fill_request(struct nbd_request *request, int type, int from, int len) void read_reply(int fd, struct nbd_request *request, struct nbd_reply *reply) { - SERVER_ERROR_ON_FAILURE(readloop(fd, reply, sizeof(*reply)), + FATAL_IF_NEGATIVE(readloop(fd, reply, sizeof(*reply)), "Couldn't read reply"); if (be32toh(reply->magic) != REPLY_MAGIC) - SERVER_ERROR("Reply magic incorrect (%p)", be32toh(reply->magic)); + fatal("Reply magic incorrect (%p)", be32toh(reply->magic)); if (be32toh(reply->error) != 0) - SERVER_ERROR("Server replied with error %d", be32toh(reply->error)); + fatal("Server replied with error %d", be32toh(reply->error)); if (strncmp(request->handle, reply->handle, 8) != 0) - SERVER_ERROR("Did not reply with correct handle"); + fatal("Did not reply with correct handle"); } void socket_nbd_read(int fd, off64_t from, int len, int out_fd, void* out_buf) @@ -64,16 +64,16 @@ void socket_nbd_read(int fd, off64_t from, int len, int out_fd, void* out_buf) struct nbd_reply reply; fill_request(&request, REQUEST_READ, from, len); - SERVER_ERROR_ON_FAILURE(writeloop(fd, &request, sizeof(request)), + FATAL_IF_NEGATIVE(writeloop(fd, &request, sizeof(request)), "Couldn't write request"); read_reply(fd, &request, &reply); if (out_buf) { - SERVER_ERROR_ON_FAILURE(readloop(fd, out_buf, len), + FATAL_IF_NEGATIVE(readloop(fd, out_buf, len), "Read failed"); } else { - SERVER_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( splice_via_pipe_loop(fd, out_fd, len), "Splice failed" ); @@ -86,15 +86,15 @@ void socket_nbd_write(int fd, off64_t from, int len, int in_fd, void* in_buf) struct nbd_reply reply; fill_request(&request, REQUEST_WRITE, from, len); - SERVER_ERROR_ON_FAILURE(writeloop(fd, &request, sizeof(request)), + FATAL_IF_NEGATIVE(writeloop(fd, &request, sizeof(request)), "Couldn't write request"); if (in_buf) { - SERVER_ERROR_ON_FAILURE(writeloop(fd, in_buf, len), + FATAL_IF_NEGATIVE(writeloop(fd, in_buf, len), "Write failed"); } else { - SERVER_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( splice_via_pipe_loop(in_fd, fd, len), "Splice failed" ); @@ -106,7 +106,7 @@ void socket_nbd_write(int fd, off64_t from, int len, int in_fd, void* in_buf) #define CHECK_RANGE(error_type) { \ off64_t size = socket_nbd_read_hello(params->client); \ if (params->from < 0 || (params->from + params->len) > size) \ - SERVER_ERROR(error_type \ + fatal(error_type \ " request %d+%d is out of range given size %d", \ params->from, params->len, size\ ); \ diff --git a/src/remote.c b/src/remote.c index 8e56def..29d805b 100644 --- a/src/remote.c +++ b/src/remote.c @@ -17,12 +17,12 @@ void do_remote_command(char* command, char* socket_name, int argc, char** argv) memset(&address, 0, sizeof(address)); - SERVER_ERROR_ON_FAILURE(remote, "Couldn't create client socket"); + FATAL_IF_NEGATIVE(remote, "Couldn't create client socket"); address.sun_family = AF_UNIX; strncpy(address.sun_path, socket_name, sizeof(address.sun_path)); - SERVER_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( connect(remote, (struct sockaddr*) &address, sizeof(address)), "Couldn't connect to %s", socket_name ); @@ -35,7 +35,7 @@ void do_remote_command(char* command, char* socket_name, int argc, char** argv) } write(remote, &newline, 1); - SERVER_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( read_until_newline(remote, response, max_response), "Couldn't read response from %s", socket_name ); diff --git a/src/self_pipe.c b/src/self_pipe.c index 71861fe..107f8f5 100644 --- a/src/self_pipe.c +++ b/src/self_pipe.c @@ -35,8 +35,7 @@ void self_pipe_server_error( int err, char *msg ) strerror_r( err, errbuf, 1024 ); - debug(msg); - SERVER_ERROR( "%s\t%s", msg, errbuf ); + fatal( "%s\t%s", msg, errbuf ); } /** diff --git a/src/serve.c b/src/serve.c index b6e5000..8b846e9 100644 --- a/src/serve.c +++ b/src/serve.c @@ -46,7 +46,7 @@ int server_lock_io( struct server * serve) { NULLCHECK( serve ); - SERVER_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( pthread_mutex_lock(&serve->l_io), "Problem with I/O lock" ); @@ -59,7 +59,7 @@ void server_unlock_io( struct server* serve ) { NULLCHECK( serve ); - SERVER_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( pthread_mutex_unlock(&serve->l_io), "Problem with I/O unlock" ); @@ -75,26 +75,26 @@ void serve_open_server_socket(struct server* params) params->server_fd= socket(params->bind_to.generic.sa_family == AF_INET ? PF_INET : PF_INET6, SOCK_STREAM, 0); - SERVER_ERROR_ON_FAILURE(params->server_fd, + FATAL_IF_NEGATIVE(params->server_fd, "Couldn't create server socket"); - SERVER_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( setsockopt(params->server_fd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)), "Couldn't set SO_REUSEADDR" ); - SERVER_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( setsockopt(params->server_fd, IPPROTO_TCP, TCP_NODELAY, &optval, sizeof(optval)), "Couldn't set TCP_NODELAY" ); - SERVER_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( bind(params->server_fd, ¶ms->bind_to.generic, sizeof(params->bind_to)), "Couldn't bind server to IP address" ); - SERVER_ERROR_ON_FAILURE( + FATAL_IF_NEGATIVE( listen(params->server_fd, params->tcp_backlog), "Couldn't listen on server socket" ); @@ -121,7 +121,7 @@ int tryjoin_client_thread( struct client_tbl_entry *entry, int (*joinfunc)(pthre if (joinfunc(entry->thread, &status) != 0) { if (errno != EBUSY) - SERVER_ERROR_ON_FAILURE(-1, "Problem with joining thread"); + FATAL_IF_NEGATIVE(-1, "Problem with joining thread"); } else { debug("nbd thread %p exited (%s) with status %ld", @@ -187,8 +187,6 @@ int cleanup_and_find_client_slot(struct server* params) break; } } - - if ( -1 == slot ) { debug( "No client slot found." ); } return slot; } @@ -230,7 +228,7 @@ int server_should_accept_client( debug( "Rejecting client %s: Access control error", s_client_address ); debug( "We %s have an acl, and default_deny is %s", (params->acl ? "do" : "do not"), - (params->default_deny ? "true" : "false") ); + (params->acl->default_deny ? "true" : "false") ); write(client_fd, "Access control error", 20); return 0; } @@ -263,6 +261,7 @@ void accept_nbd_client( slot = cleanup_and_find_client_slot(params); if (slot < 0) { + warn("too many clients to accept connection"); write(client_fd, "Too many clients", 16); close(client_fd); return; @@ -297,6 +296,8 @@ int server_is_closed(struct server* serve) void server_close_clients( struct server *params ) { NULLCHECK(params); + + info("closing all clients"); int i, j; struct client_tbl_entry *entry; @@ -318,6 +319,7 @@ void server_close_clients( struct server *params ) void serve_accept_loop(struct server* params) { NULLCHECK( params ); + info("accept loop starting"); while (1) { int activity_fd, client_fd; union mysockaddr client_address; @@ -330,7 +332,7 @@ void serve_accept_loop(struct server* params) if (params->control_socket_name) FD_SET(params->control_fd, &fds); - SERVER_ERROR_ON_FAILURE(select(FD_SETSIZE, &fds, + FATAL_IF_NEGATIVE(select(FD_SETSIZE, &fds, NULL, NULL, NULL), "select() failed"); if ( self_pipe_fd_isset( params->close_signal, &fds ) ){ @@ -343,11 +345,11 @@ void serve_accept_loop(struct server* params) client_fd = accept(activity_fd, &client_address.generic, &socklen); if (activity_fd == params->server_fd) { - debug("Accepted nbd client socket"); + info("Accepted nbd client socket"); accept_nbd_client(params, client_fd, &client_address); } if (activity_fd == params->control_fd) { - debug("Accepted control client socket"); + info("Accepted control client socket"); accept_control_connection(params, client_fd, &client_address); } @@ -364,10 +366,10 @@ void serve_init_allocation_map(struct server* params) int fd = open(params->filename, O_RDONLY); off64_t size; - SERVER_ERROR_ON_FAILURE(fd, "Couldn't open %s", params->filename); + FATAL_IF_NEGATIVE(fd, "Couldn't open %s", params->filename); size = lseek64(fd, 0, SEEK_END); params->size = size; - SERVER_ERROR_ON_FAILURE(size, "Couldn't find size of %s", + FATAL_IF_NEGATIVE(size, "Couldn't find size of %s", params->filename); params->allocation_map = build_allocation_map(fd, size, block_allocation_resolution); @@ -379,34 +381,43 @@ void serve_init_allocation_map(struct server* params) void serve_signal_close( struct server * serve ) { NULLCHECK( serve ); + info("signalling close"); self_pipe_signal( serve->close_signal ); } /** Closes sockets, frees memory and waits for all client threads to finish */ -void serve_cleanup(struct server* params) +void serve_cleanup(struct server* params, int fatal) { NULLCHECK( params ); + + info("cleaning up"); int i; - close(params->server_fd); - close(params->control_fd); + if (params->server_fd) + close(params->server_fd); + if (params->control_fd) + close(params->control_fd); if (params->acl) free(params->acl); - //free(params->filename); if (params->control_socket_name) - //free(params->control_socket_name); + ; pthread_mutex_destroy(¶ms->l_io); if (params->proxy_fd); close(params->proxy_fd); - self_pipe_destroy( params->close_signal ); + if (params->close_signal) + self_pipe_destroy( params->close_signal ); - free(params->allocation_map); + if (params->allocation_map) + free(params->allocation_map); - if (params->mirror) - debug("mirror thread running! this should not happen!"); + if (params->mirror) { + pthread_t mirror_t = params->mirror->thread; + params->mirror->signal_abandon = 1; + pthread_join(mirror_t, NULL); + } for (i=0; i < MAX_NBD_CLIENTS; i++) { void* status; @@ -422,18 +433,19 @@ void serve_cleanup(struct server* params) void do_serve(struct server* params) { NULLCHECK( params ); - + + error_set_handler((cleanup_handler*) serve_cleanup, params); pthread_mutex_init(¶ms->l_io, NULL); params->close_signal = self_pipe_create(); if ( NULL == params->close_signal) { - SERVER_ERROR( "close signal creation failed" ); + fatal( "close signal creation failed" ); } serve_open_server_socket(params); serve_open_control_socket(params); serve_init_allocation_map(params); serve_accept_loop(params); - serve_cleanup(params); + serve_cleanup(params, 0); } diff --git a/src/serve.h b/src/serve.h index f96222b..317cc2b 100644 --- a/src/serve.h +++ b/src/serve.h @@ -22,6 +22,8 @@ enum mirror_finish_action { struct mirror_status { pthread_t thread; + /* set to 1, then join thread to make mirror terminate early */ + int signal_abandon; int client; char *filename; off64_t max_bytes_per_second; diff --git a/src/util.c b/src/util.c index 32729b4..5b5e8e4 100644 --- a/src/util.c +++ b/src/util.c @@ -9,47 +9,43 @@ #include "util.h" -static pthread_t main_thread; -static int global_debug; +pthread_key_t cleanup_handler_key; + +int log_level = 1; void error_init() { - main_thread = pthread_self(); + pthread_key_create(&cleanup_handler_key, free); } -void error(int consult_errno, int fatal, int close_socket, pthread_mutex_t* unlock, const char* format, ...) +void error_handler(int fatal) +{ + DECLARE_ERROR_CONTEXT(context); + + if (!context) { + pthread_exit((void*) 1); + } + + longjmp(context->jmp, 1); +} + +void mylog(int line_level, const char* format, ...) { va_list argptr; - fprintf(stderr, "*** "); + if (line_level < log_level) + return; va_start(argptr, format); vfprintf(stderr, format, argptr); va_end(argptr); - - if (consult_errno) { - fprintf(stderr, " (errno=%d, %s)", errno, strerror(errno)); - } - - if (close_socket) { close(close_socket); } - if (unlock) { pthread_mutex_unlock(unlock); } - fprintf(stderr, "\n"); - - if (fatal || pthread_equal(pthread_self(), main_thread)) { - exit(1); - } - else { - fprintf(stderr, "Killing Thread\n"); - pthread_exit((void*) 1); - } } void* xrealloc(void* ptr, size_t size) { void* p = realloc(ptr, size); - if (p == NULL) - SERVER_ERROR("couldn't xrealloc %d bytes", size); + FATAL_IF_NULL(p, "couldn't xrealloc %d bytes", ptr ? "realloc" : "malloc", size); return p; } @@ -60,26 +56,3 @@ void* xmalloc(size_t size) return p; } - -void set_debug(int value) { - global_debug = value; -} - -#ifdef DEBUG -# include -# include - -void debug(const char *msg, ...) { - va_list argp; - va_start( argp, msg ); - - if ( global_debug ) { - fprintf(stderr, "%08x %4d: ", (int) pthread_self(), (int) clock() ); - vfprintf(stderr, msg, argp); - fprintf(stderr, "\n"); - } - - va_end( argp ); -} -#endif - diff --git a/src/util.h b/src/util.h index edd4386..86897b2 100644 --- a/src/util.h +++ b/src/util.h @@ -3,36 +3,106 @@ #include #include - -void error_init(); - -void error(int consult_errno, int fatal, int close_socket, pthread_mutex_t* unlock, const char* format, ...); +#include +#include void* xrealloc(void* ptr, size_t size); - void* xmalloc(size_t size); -void set_debug(int value); +typedef void (cleanup_handler)(void* /* context */, int /* is fatal? */); + +/* set from 0 - 5 depending on what level of verbosity you want */ +extern int log_level; + +/* set up the error globals */ +void error_init(); + +/* error_set_handler must be a macro not a function due to setjmp stack rules */ +#include + +struct error_handler_context { + jmp_buf jmp; + cleanup_handler* handler; + void* data; +}; + +#define DECLARE_ERROR_CONTEXT(name) \ + struct error_handler_context *name = (struct error_handler_context*) \ + pthread_getspecific(cleanup_handler_key) + +/* clean up with the given function & data when error_handler() is invoked, + * non-fatal errors will also return here (if that's dangerous, use fatal() + * instead of error()). + * + * error handlers are thread-local, so you need to call this when starting a + * new thread. + */ +extern pthread_key_t cleanup_handler_key; +#define error_set_handler(cleanfn, cleandata) \ +{ \ + DECLARE_ERROR_CONTEXT(old); \ + struct error_handler_context *context = \ + xmalloc(sizeof(struct error_handler_context)); \ + context->handler = (cleanfn); \ + context->data = (cleandata); \ + \ + switch (setjmp(context->jmp)) \ + { \ + case 0: /* setup time */ \ + if (old) \ + free(old); \ + pthread_setspecific(cleanup_handler_key, context); \ + break; \ + case 1: /* fatal error, terminate thread */ \ + context->handler(context->data, 1); \ + pthread_exit((void*) 1); \ + abort(); \ + case 2: /* non-fatal error, return to context of error handler setup */ \ + context->handler(context->data, 0); \ + default: \ + abort(); \ + } \ +} + + +/* invoke the error handler - longjmps away, don't use directly */ +void error_handler(int fatal); + +/* mylog a line at the given level (0 being most verbose) */ +void mylog(int line_level, const char* format, ...); + #ifdef DEBUG -void debug(const char*msg, ...); +# define debug(msg, ...) mylog(0, "%s:%d: " msg, __FILE__, __LINE__, ##__VA_ARGS__) #else -/* no-op */ -# define debug( msg, ...) +# define debug(msg, ...) /* no-op */ #endif -#define CLIENT_ERROR(msg, ...) \ - error(0, 0, client->socket, &client->serve->l_io, msg, ##__VA_ARGS__) -#define CLIENT_ERROR_ON_FAILURE(test, msg, ...) \ - if (test < 0) { error(1, 0, client->socket, &client->serve->l_io, msg, ##__VA_ARGS__); } +/* informational message, not expected to be compiled out */ +#define info(msg, ...) mylog(1, "%s:%d: " msg, __FILE__, __LINE__, ##__VA_ARGS__) -#define SERVER_ERROR(msg, ...) \ - error(0, 1, 0, NULL, msg, ##__VA_ARGS__) -#define SERVER_ERROR_ON_FAILURE(test, msg, ...) \ - if (test < 0) { error(1, 1, 0, NULL, msg, ##__VA_ARGS__); } +/* messages that might indicate a problem */ +#define warn(msg, ...) mylog(2, "%s:%d: " msg, __FILE__, __LINE__, ##__VA_ARGS__) +/* mylog a message and invoke the error handler to recover */ +#define error(msg, ...) { \ + mylog(3, "*** %s:%d: " msg, __FILE__, __LINE__, ##__VA_ARGS__); \ + error_handler(0); \ +} -#define NULLCHECK(x) \ - do { if ( NULL == (x) ) { SERVER_ERROR( "Null " #x "." ); } } while(0) +/* mylog a message and invoke the error handler to kill the current thread */ +#define fatal(msg, ...) { \ + mylog(4, "*** %s:%d: " msg, __FILE__, __LINE__, ##__VA_ARGS__); \ + error_handler(1); \ +} + +#define ERROR_IF_NULL(value, msg, ...) if (NULL == value) error(msg " (errno=%d, %s)", ##__VA_ARGS__, errno, strerror(errno)) +#define ERROR_IF_NEGATIVE(value, msg, ...) if (value < 0) error(msg, ##__VA_ARGS__) +#define ERROR_IF_ZERO(value, msg, ...) if (0 == value) error(msg, ##__VA_ARGS__) +#define FATAL_IF_NULL(value, msg, ...) if (NULL == value) fatal(msg, ##__VA_ARGS__) +#define FATAL_IF_NEGATIVE(value, msg, ...) if (value < 0) fatal(msg " (errno=%d, %s)", ##__VA_ARGS__, errno, strerror(errno)) +#define FATAL_IF_ZERO(value, msg, ...) if (0 == value) fatal(msg, ##__VA_ARGS__) + +#define NULLCHECK(value) FATAL_IF_NULL(value, "BUG: " #value " is null") #endif diff --git a/tests/check_acl.c b/tests/check_acl.c index 06e4983..34263ef 100644 --- a/tests/check_acl.c +++ b/tests/check_acl.c @@ -2,6 +2,7 @@ #include #include "acl.h" +#include "util.h" START_TEST( test_null_acl ) { @@ -110,11 +111,11 @@ Suite* acl_suite() int main(void) { - set_debug(1); int number_failed; Suite *s = acl_suite(); SRunner *sr = srunner_create(s); srunner_run_all(sr, CK_NORMAL); + log_level = 0; number_failed = srunner_ntests_failed(sr); srunner_free(sr); return (number_failed == 0) ? 0 : 1;