Split code out into separate compilation units (first pass, anyway).

This commit is contained in:
Matthew Bloch
2012-05-17 20:14:22 +01:00
parent aec90e5244
commit 0432fef8f5
11 changed files with 790 additions and 683 deletions

View File

@@ -1,5 +1,7 @@
DEBUG = true
SOURCES = %w( flexnbd ioutil readwrite serve util )
OBJECTS = SOURCES.map { |s| "#{s}.o" }
LIBS = %w( pthread )
CCFLAGS = %w( -Wall )
LDFLAGS = []
@@ -12,10 +14,11 @@ end
rule 'default' => 'flexnbd'
rule 'flexnbd' => 'flexnbd.o' do |t|
rule 'flexnbd' => OBJECTS do |t|
sh "gcc #{LDFLAGS.join(' ')} "+
LIBS.map { |l| "-l#{l}" }.join(" ")+
" -o #{t.name} #{t.source}"
" -o #{t.name} "+
t.sources.join(" ")
end
rule '.o' => '.c' do |t|
@@ -23,5 +26,5 @@ rule '.o' => '.c' do |t|
end
rule 'clean' do
sh "rm -f flexnbd.o flexnbd"
sh "rm -f flexnbd "+OBJECTS.join(" ")
end

39
bitset.h Normal file
View File

@@ -0,0 +1,39 @@
#ifndef __BITSET_H
#define __BITSET_H
#include <string.h>
/**
 * Build the one-bit mask for bit position `num` inside its byte.
 *
 * Only num % 8 matters: position 0 yields 0x01, position 7 yields 0x80.
 * The mask is returned as a plain char, like the bitset bytes it is
 * combined with.
 */
static inline char char_with_bit_set(int num) {
	const int shift = num % 8;
	const int mask = 1 << shift;
	return mask;
}
/**
 * Test bit `idx` of the bitset starting at `b`.
 *
 * Bit idx lives in byte idx/8, at position idx%8 within that byte.
 * Returns 1 when the bit is set and 0 when it is clear.
 */
static inline int bit_is_set(char* b, int idx) {
	const char mask = char_with_bit_set(idx);
	const char byte = b[idx/8];
	return (byte & mask) ? 1 : 0;
}
/**
 * Inverse of bit_is_set(): returns 1 when bit `idx` of `b` is clear,
 * 0 when it is set.
 */
static inline int bit_is_clear(char* b, int idx) {
	return bit_is_set(b, idx) ? 0 : 1;
}
/**
 * Set bit `idx` of the bitset starting at `b`, leaving all other bits alone.
 *
 * Bit idx lives in byte idx/8, at position idx%8 within that byte.
 */
static inline void bit_set(char* b, int idx) {
	/* OR the mask in. An AND here would wipe the other seven bits of the
	 * byte and could never turn a zero bit on. The mask is the same value
	 * char_with_bit_set(idx) produces. */
	b[idx/8] |= (char) (1 << (idx%8));
}
/**
 * Clear bit `idx` of the bitset starting at `b`, leaving all other bits
 * alone.
 */
static inline void bit_clear(char* b, int idx) {
	/* Every bit except the target one is kept. */
	const char keep = ~char_with_bit_set(idx);
	b[idx/8] &= keep;
}
/**
 * Set `len` consecutive bits of the bitset `b`, starting at bit `from`.
 *
 * Works in three phases: single bits up to the next byte boundary, whole
 * bytes with one memset, then the remaining (fewer than 8) bits.
 */
static inline void bit_set_range(char* b, int from, int len) {
	/* Leading bits, until `from` is byte-aligned. */
	for (; from%8 != 0 && len > 0; len--)
		bit_set(b, from++);
	/* Whole bytes in one go. Step past them afterwards so the loop below
	 * only handles the tail instead of redoing the entire range. */
	if (len >= 8) {
		int whole_bytes = len/8;
		memset(b+(from/8), 255, whole_bytes);
		from += whole_bytes*8;
		len -= whole_bytes*8;
	}
	/* Trailing bits. */
	for (; len > 0; len--)
		bit_set(b, from++);
}
/**
 * Clear `len` consecutive bits of the bitset `b`, starting at bit `from`.
 *
 * Works in three phases: single bits up to the next byte boundary, whole
 * bytes with one memset, then the remaining (fewer than 8) bits.
 */
static inline void bit_clear_range(char* b, int from, int len) {
	/* Leading bits, until `from` is byte-aligned. */
	for (; from%8 != 0 && len > 0; len--)
		bit_clear(b, from++);
	/* Whole bytes in one go. Step past them afterwards so the loop below
	 * only handles the tail instead of redoing the entire range. */
	if (len >= 8) {
		int whole_bytes = len/8;
		memset(b+(from/8), 0, whole_bytes);
		from += whole_bytes*8;
		len -= whole_bytes*8;
	}
	/* Trailing bits. */
	for (; len > 0; len--)
		bit_clear(b, from++);
}
#endif

690
flexnbd.c
View File

@@ -1,57 +1,15 @@
#define _LARGEFILE64_SOURCE
#define _GNU_SOURCE
#include "params.h"
#include "util.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdarg.h>
#include <malloc.h>
#include <errno.h>
#include <endian.h>
#include <unistd.h>
#include <fcntl.h>
#include <pthread.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <sys/mman.h>
#include <sys/sendfile.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <linux/fs.h>
#include <linux/fiemap.h>
/* http://linux.derkeiler.com/Mailing-Lists/Kernel/2003-09/2332.html */
#define INIT_PASSWD "NBDMAGIC"
#define INIT_MAGIC 0x0000420281861253
#define REQUEST_MAGIC 0x25609513
#define REPLY_MAGIC 0x67446698
#define REQUEST_READ 0
#define REQUEST_WRITE 1
#define REQUEST_DISCONNECT 2
#include <linux/types.h>
struct nbd_init {
char passwd[8];
__be64 magic;
__be64 size;
char reserved[128];
};
struct nbd_request {
__be32 magic;
__be32 type; /* == READ || == WRITE */
char handle[8];
__be64 from;
__be32 len;
} __attribute__((packed));
struct nbd_reply {
__be32 magic;
__be32 error; /* 0 = ok, else error */
char handle[8]; /* handle you got from request */
};
void syntax()
{
@@ -65,638 +23,6 @@ void syntax()
exit(1);
}
static pthread_t server_thread_id;
void error(int consult_errno, int close_socket, const char* format, ...)
{
va_list argptr;
fprintf(stderr, "*** ");
va_start(argptr, format);
vfprintf(stderr, format, argptr);
va_end(argptr);
if (consult_errno) {
fprintf(stderr, " (errno=%d, %s)", errno, strerror(errno));
}
if (close_socket)
close(close_socket);
fprintf(stderr, "\n");
if (pthread_equal(pthread_self(), server_thread_id))
exit(1);
else
pthread_exit((void*) 1);
}
#ifndef DEBUG
# define debug(msg, ...)
#else
# include <sys/times.h>
# define debug(msg, ...) fprintf(stderr, "%08x %4d: " msg "\n" , \
(int) pthread_self(), (int) clock(), ##__VA_ARGS__)
#endif
#define CLIENT_ERROR(msg, ...) \
error(0, client->socket, msg, ##__VA_ARGS__)
#define CLIENT_ERROR_ON_FAILURE(test, msg, ...) \
if (test < 0) { error(1, client->socket, msg, ##__VA_ARGS__); }
#define SERVER_ERROR(msg, ...) \
error(0, 0, msg, ##__VA_ARGS__)
#define SERVER_ERROR_ON_FAILURE(test, msg, ...) \
if (test < 0) { error(1, 0, msg, ##__VA_ARGS__); }
void* xrealloc(void* ptr, size_t size)
{
void* p = realloc(ptr, size);
if (p == NULL)
SERVER_ERROR("couldn't xrealloc %d bytes", size);
return p;
}
void* xmalloc(size_t size)
{
void* p = xrealloc(NULL, size);
memset(p, 0, size);
return p;
}
union mysockaddr {
unsigned short family;
struct sockaddr generic;
struct sockaddr_in v4;
struct sockaddr_in6 v6;
};
struct ip_and_mask {
union mysockaddr ip;
int mask;
};
struct mode_serve_params {
union mysockaddr bind_to;
int acl_entries;
struct ip_and_mask** acl;
char* filename;
int tcp_backlog;
int server;
int threads;
pthread_mutex_t block_allocation_map_lock;
char* block_allocation_map;
};
struct mode_readwrite_params {
union mysockaddr connect_to;
off64_t from;
off64_t len;
int data_fd;
int client;
};
struct client_params {
int socket;
char* filename;
int fileno;
off64_t size;
char* mapped;
pthread_mutex_t block_allocation_map_lock;
char* block_allocation_map;
}};
union mode_params {
struct mode_serve_params serve;
struct mode_readwrite_params readwrite;
};
static inline int char_with_bit_set(int num) {
return 1<<(num%8);
}
static inline int bit_is_set(char* b, int idx) {
return (b[idx/8] & char_with_bit_set(idx)) != 0;
}
static inline int bit_is_clear(char* b, int idx) {
return !bit_is_set(b, idx);
}
static inline void bit_set(char* b, int idx) {
b[idx/8] &= char_with_bit_set(idx);
}
static inline void bit_clear(char* b, int idx) {
b[idx/8] &= ~char_with_bit_set(idx);
}
static inline void bit_set_range(char* b, int from, int len) {
for (; b%8 != 0 && len > 0; len--)
bit_set(b, from++, 1);
if (len >= 8)
memset(b+(from/8), 255, len/8);
for (; len > 0; len--)
bit_set(b, from++);
}
static inline void bit_clear_range(char* b, int from, int len) {
for (; b%8 != 0 && len > 0; len--)
bit_clear(b, from++, 1);
if (len >= 8)
memset(b+(from/8), 0, len/8);
for (; len > 0; len--)
bit_clear(b, from++);
}
char* build_allocation_map(int fd, off64_t size, int resolution)
{
char *allocation_map = xmalloc((size+resolution)/resolution);
struct fiemap *fiemap;
fiemap = (struct fiemap*) xmalloc(sizeof(struct fiemap));
fiemap->fm_start = from;
fiemap->fm_length = len;
fiemap->fm_flags = 0;
fiemap->fm_extent_count = 0;
fiemap->fm_mapped_extents = 0;
/* Find out how many extents there are */
if (ioctl(fd, FS_IOC_FIEMAP, fiemap) < 0)
return NULL;
/* Resize fiemap to allow us to read in the extents */
fiemap = (struct fiemap*)xrealloc(
fiemap,
sizeof(struct fiemap) + (
sizeof(struct fiemap_extent) *
fiemap->fm_mapped_extents
)
);
fiemap->fm_extent_count = fiemap->fm_mapped_extents;
fiemap->fm_mapped_extents = 0;
if (ioctl(fd, FS_IOC_FIEMAP, fiemap) < -1)
return NULL;
for (i=0;i<fiemap->fm_mapped_extents;i++)
bit_set_range(
allocation_map,
fiemap->fm_extents[i].fe_logical / resolution,
fiemap->fm_extents[i].fe_length / resolution
);
free(fiemap);
return allocation_map;
}
int writeloop(int filedes, const void *buffer, size_t size)
{
size_t written=0;
while (written < size) {
size_t result = write(filedes, buffer+written, size-written);
if (result == -1)
return -1;
written += result;
}
return 0;
}
int readloop(int filedes, void *buffer, size_t size)
{
size_t readden=0;
while (readden < size) {
size_t result = read(filedes, buffer+readden, size-readden);
if (result == 0 /* EOF */ || result == -1 /* error */)
return -1;
readden += result;
}
return 0;
}
int sendfileloop(int out_fd, int in_fd, off64_t *offset, size_t count)
{
size_t sent=0;
while (sent < count) {
size_t result = sendfile64(out_fd, in_fd, offset+sent, count-sent);
if (result == -1)
return -1;
sent += result;
}
return 0;
}
int splice_via_pipe_loop(int fd_in, int fd_out, size_t len)
{
int pipefd[2];
size_t spliced=0;
if (pipe(pipefd) == -1)
return -1;
while (spliced < len) {
size_t r1,r2;
r1 = splice(fd_in, NULL, pipefd[1], NULL, len-spliced, 0);
if (r1 <= 0)
break;
r2 = splice(pipefd[0], NULL, fd_out, NULL, r1, 0);
if (r1 != r2)
break;
spliced += r1;
}
close(pipefd[0]);
close(pipefd[1]);
return spliced < len ? -1 : 0;
}
int client_serve_request(struct client_params* client)
{
off64_t offset;
struct nbd_request request;
struct nbd_reply reply;
struct unallocated_block** unallocated;
if (readloop(client->socket, &request, sizeof(request)) == -1) {
if (errno == 0) {
debug("EOF reading request");
return 1; /* neat point to close the socket */
}
else {
CLIENT_ERROR_ON_FAILURE(-1, "Error reading request");
}
}
reply.magic = htobe32(REPLY_MAGIC);
reply.error = htobe32(0);
memcpy(reply.handle, request.handle, 8);
debug("request type %d", be32toh(request.type));
if (be32toh(request.magic) != REQUEST_MAGIC)
CLIENT_ERROR("Bad magic %08x", be32toh(request.magic));
switch (be32toh(request.type))
{
case REQUEST_READ:
case REQUEST_WRITE:
/* check it's not out of range */
if (be64toh(request.from) < 0 ||
be64toh(request.from)+be32toh(request.len) > client->size) {
debug("request read %ld+%d out of range",
be64toh(request.from),
be32toh(request.len)
);
reply.error = htobe32(1);
write(client->socket, &reply, sizeof(reply));
return 0;
}
break;
case REQUEST_DISCONNECT:
debug("request disconnect");
return 1;
default:
CLIENT_ERROR("Unknown request %08x", be32toh(request.type));
}
switch (be32toh(request.type))
{
case REQUEST_READ:
debug("request read %ld+%d", be64toh(request.from), be32toh(request.len));
write(client->socket, &reply, sizeof(reply));
offset = be64toh(request.from);
CLIENT_ERROR_ON_FAILURE(
sendfileloop(
client->socket,
client->fileno,
&offset,
be32toh(request.len)
),
"sendfile failed from=%ld, len=%d",
offset,
be32toh(request.len)
);
break;
case REQUEST_WRITE:
debug("request write %ld+%d", be64toh(request.from), be32toh(request.len));
#ifdef _LINUX_FIEMAP_H
unallocated = read_unallocated_blocks(
client->fileno,
be64toh(request.from),
be32toh(request.len)
);
if (unallocated == NULL)
CLIENT_ERROR("Couldn't read unallocated blocks list");
CLIENT_ERROR_ON_FAILURE(
read_from_socket_avoiding_holes(
client->socket,
);
free(fiemap);
#else
CLIENT_ERROR_ON_FAILURE(
readloop(
client->socket,
client->mapped + be64toh(request.from),
be32toh(request.len)
),
"read failed from=%ld, len=%d",
be64toh(request.from),
be32toh(request.len)
);
#endif
write(client->socket, &reply, sizeof(reply));
break;
}
return 0;
}
void client_open_file(struct client_params* client)
{
client->fileno = open(client->filename, O_RDWR|O_DIRECT|O_SYNC);
CLIENT_ERROR_ON_FAILURE(client->fileno, "Couldn't open %s",
client->filename);
client->size = lseek64(client->fileno, 0, SEEK_END);
CLIENT_ERROR_ON_FAILURE(client->fileno, "Couldn't seek to end of %s",
client->filename);
client->mapped = mmap64(NULL, client->size, PROT_READ|PROT_WRITE,
MAP_SHARED, client->fileno, 0);
CLIENT_ERROR_ON_FAILURE((long) client->mapped, "Couldn't map file %s",
client->filename);
debug("opened %s size %ld on fd %d @ %p", client->filename, client->size, client->fileno, client->mapped);
}
void client_send_hello(struct client_params* client)
{
struct nbd_init init;
memcpy(init.passwd, INIT_PASSWD, sizeof(INIT_PASSWD));
init.magic = htobe64(INIT_MAGIC);
init.size = htobe64(client->size);
memset(init.reserved, 0, 128);
CLIENT_ERROR_ON_FAILURE(
writeloop(client->socket, &init, sizeof(init)),
"Couldn't send hello"
);
}
void* client_serve(void* client_uncast)
{
struct client_params* client = (struct client_params*) client_uncast;
client_open_file(client);
client_send_hello(client);
while (client_serve_request(client) == 0)
;
CLIENT_ERROR_ON_FAILURE(
close(client->socket),
"Couldn't close socket %d",
client->socket
);
free(client);
return NULL;
}
static int testmasks[9] = { 0,128,192,224,240,248,252,254,255 };
int is_included_in_acl(int list_length, struct ip_and_mask** list, struct sockaddr* test)
{
int i;
for (i=0; i < list_length; i++) {
struct ip_and_mask *entry = list[i];
int testbits;
char *raw_address1, *raw_address2;
debug("checking acl entry %d", i);
if (test->sa_family != entry->ip.family)
continue;
if (test->sa_family == AF_INET) {
raw_address1 = (char*)
&((struct sockaddr_in*) test)->sin_addr;
raw_address2 = (char*) &entry->ip.v4.sin_addr;
}
else if (test->sa_family == AF_INET6) {
raw_address1 = (char*)
&((struct sockaddr_in6*) test)->sin6_addr;
raw_address2 = (char*) &entry->ip.v6.sin6_addr;
}
for (testbits = entry->mask; testbits > 0; testbits -= 8) {
debug("testbits=%d, c1=%d, c2=%d", testbits, raw_address1[0], raw_address2[0]);
if (testbits >= 8) {
if (raw_address1[0] != raw_address2[0])
goto no_match;
}
else {
if ((raw_address1[0] & testmasks[testbits%8]) !=
(raw_address2[0] & testmasks[testbits%8]) )
goto no_match;
}
raw_address1++;
raw_address2++;
}
return 1;
no_match: ;
debug("no match");
}
return 0;
}
void serve_open_socket(struct mode_serve_params* params)
{
params->server = socket(PF_INET, SOCK_STREAM, 0);
SERVER_ERROR_ON_FAILURE(params->server,
"Couldn't create server socket");
SERVER_ERROR_ON_FAILURE(
bind(params->server, &params->bind_to.generic,
sizeof(params->bind_to.generic)),
"Couldn't bind server to IP address"
);
SERVER_ERROR_ON_FAILURE(
listen(params->server, params->tcp_backlog),
"Couldn't listen on server socket"
);
}
void serve_accept_loop(struct mode_serve_params* params)
{
while (1) {
pthread_t client_thread;
struct sockaddr client_address;
struct client_params* client_params;
socklen_t socket_length=0;
int client_socket = accept(params->server, &client_address,
&socket_length);
SERVER_ERROR_ON_FAILURE(client_socket, "accept() failed");
if (params->acl &&
!is_included_in_acl(params->acl_entries, params->acl, &client_address)) {
write(client_socket, "Access control error", 20);
close(client_socket);
continue;
}
client_params = xmalloc(sizeof(struct client_params));
client_params->socket = client_socket;
client_params->filename = params->filename;
client_params->block_allocation_map =
params->block_allocation_map;
client_params->block_allocation_map_lock =
params->block_allocation_map_lock;
client_thread = pthread_create(&client_thread, NULL,
client_serve, client_params);
SERVER_ERROR_ON_FAILURE(client_thread,
"Failed to create client thread");
/* FIXME: keep track of them? */
/* FIXME: maybe shouldn't be fatal? */
}
}
void do_serve(struct mode_serve_params* params)
{
serve_open_socket(params);
serve_accept_loop(params);
}
int socket_connect(struct sockaddr* to)
{
int fd = socket(PF_INET, SOCK_STREAM, 0);
SERVER_ERROR_ON_FAILURE(fd, "Couldn't create client socket");
SERVER_ERROR_ON_FAILURE(connect(fd, to, sizeof(*to)),
"connect failed");
return fd;
}
off64_t socket_nbd_read_hello(int fd)
{
struct nbd_init init;
SERVER_ERROR_ON_FAILURE(readloop(fd, &init, sizeof(init)),
"Couldn't read init");
if (strncmp(init.passwd, INIT_PASSWD, 8) != 0)
SERVER_ERROR("wrong passwd");
if (be64toh(init.magic) != INIT_MAGIC)
SERVER_ERROR("wrong magic (%x)", be64toh(init.magic));
return be64toh(init.size);
}
void fill_request(struct nbd_request *request, int type, int from, int len)
{
request->magic = htobe32(REQUEST_MAGIC);
request->type = htobe32(type);
((int*) request->handle)[0] = rand();
((int*) request->handle)[1] = rand();
request->from = htobe64(from);
request->len = htobe32(len);
}
void read_reply(int fd, struct nbd_request *request, struct nbd_reply *reply)
{
SERVER_ERROR_ON_FAILURE(readloop(fd, reply, sizeof(*reply)),
"Couldn't read reply");
if (be32toh(reply->magic) != REPLY_MAGIC)
SERVER_ERROR("Reply magic incorrect (%p)", reply->magic);
if (be32toh(reply->error) != 0)
SERVER_ERROR("Server replied with error %d", reply->error);
if (strncmp(request->handle, reply->handle, 8) != 0)
SERVER_ERROR("Did not reply with correct handle");
}
void socket_nbd_read(int fd, off64_t from, int len, int out_fd, void* out_buf)
{
struct nbd_request request;
struct nbd_reply reply;
fill_request(&request, REQUEST_READ, from, len);
SERVER_ERROR_ON_FAILURE(writeloop(fd, &request, sizeof(request)),
"Couldn't write request");
read_reply(fd, &request, &reply);
if (out_buf) {
SERVER_ERROR_ON_FAILURE(readloop(fd, out_buf, len),
"Read failed");
}
else {
SERVER_ERROR_ON_FAILURE(
splice_via_pipe_loop(fd, out_fd, len),
"Splice failed"
);
}
}
void socket_nbd_write(int fd, off64_t from, int len, int in_fd, void* in_buf)
{
struct nbd_request request;
struct nbd_reply reply;
fill_request(&request, REQUEST_WRITE, from, len);
SERVER_ERROR_ON_FAILURE(writeloop(fd, &request, sizeof(request)),
"Couldn't write request");
if (in_buf) {
SERVER_ERROR_ON_FAILURE(writeloop(fd, in_buf, len),
"Write failed");
}
else {
SERVER_ERROR_ON_FAILURE(
splice_via_pipe_loop(in_fd, fd, len),
"Splice failed"
);
}
read_reply(fd, &request, &reply);
}
#define CHECK_RANGE(error_type) { \
off64_t size = socket_nbd_read_hello(params->client); \
if (params->from < 0 || (params->from + params->len) >= size) \
SERVER_ERROR(error_type \
" request %d+%d is out of range given size %d", \
params->from, params->len, size\
); \
}
void do_read(struct mode_readwrite_params* params)
{
params->client = socket_connect(&params->connect_to.generic);
CHECK_RANGE("read");
socket_nbd_read(params->client, params->from, params->len,
params->data_fd, NULL);
close(params->client);
}
void do_write(struct mode_readwrite_params* params)
{
params->client = socket_connect(&params->connect_to.generic);
CHECK_RANGE("write");
socket_nbd_write(params->client, params->from, params->len,
params->data_fd, NULL);
close(params->client);
}
#define IS_IP_VALID_CHAR(x) ( ((x) >= '0' && (x) <= '9' ) || \
((x) >= 'a' && (x) <= 'f') || \
((x) >= 'A' && (x) <= 'F' ) || \
@@ -862,6 +188,10 @@ void params_readwrite(
}
}
void do_serve(struct mode_serve_params* params);
void do_read(struct mode_readwrite_params* params);
void do_write(struct mode_readwrite_params* params);
void mode(char* mode, int argc, char **argv)
{
union mode_params params;
@@ -901,7 +231,7 @@ void mode(char* mode, int argc, char **argv)
int main(int argc, char** argv)
{
server_thread_id = pthread_self();
error_init();
if (argc < 2)
syntax();

123
ioutil.c Normal file
View File

@@ -0,0 +1,123 @@
#define _LARGEFILE64_SOURCE
#define _GNU_SOURCE
#include <sys/mman.h>
#include <sys/sendfile.h>
#include <sys/ioctl.h>
#include <sys/types.h>
#include <linux/fs.h>
#include <linux/fiemap.h>
#include <stdlib.h>
#include <unistd.h>
#include <fcntl.h>
#include "util.h"
#include "bitset.h"
/**
 * Build a bitset describing which parts of the file behind `fd` are backed
 * by extents, as reported by the FS_IOC_FIEMAP ioctl.
 *
 * Each extent's start and length are divided by `resolution` and the
 * resulting bit range is set in the map.
 *
 * @param fd          open file to query
 * @param size        number of bytes, from offset 0, to query
 * @param resolution  bytes represented by one bit of the map
 * @return a zero-initialised, xmalloc'd map with the extent ranges set
 *         (caller frees), or NULL if either FIEMAP ioctl fails
 *
 * NOTE(review): the map is allocated at one byte per block although it is
 * used one bit per block, and extent lengths are rounded down, so an extent
 * shorter than `resolution` sets no bit. Confirm both are intended.
 */
char* build_allocation_map(int fd, off64_t size, int resolution)
{
	unsigned int i;
	char *allocation_map = xmalloc((size+resolution)/resolution);
	struct fiemap *fiemap = xmalloc(sizeof(struct fiemap));

	fiemap->fm_start = 0;
	fiemap->fm_length = size;
	fiemap->fm_flags = 0;
	fiemap->fm_extent_count = 0;
	fiemap->fm_mapped_extents = 0;

	/* Find out how many extents there are */
	if (ioctl(fd, FS_IOC_FIEMAP, fiemap) < 0)
		goto fail;

	/* Resize fiemap to allow us to read in the extents */
	fiemap = xrealloc(
		fiemap,
		sizeof(struct fiemap) + (
			sizeof(struct fiemap_extent) *
			fiemap->fm_mapped_extents
		)
	);

	fiemap->fm_extent_count = fiemap->fm_mapped_extents;
	fiemap->fm_mapped_extents = 0;

	/* ioctl reports failure as -1, so the test must be "< 0". */
	if (ioctl(fd, FS_IOC_FIEMAP, fiemap) < 0)
		goto fail;

	for (i = 0; i < fiemap->fm_mapped_extents; i++)
		bit_set_range(
			allocation_map,
			fiemap->fm_extents[i].fe_logical / resolution,
			fiemap->fm_extents[i].fe_length / resolution
		);

	free(fiemap);
	return allocation_map;

fail:
	/* Release both buffers instead of leaking them on the error path. */
	free(fiemap);
	free(allocation_map);
	return NULL;
}
/**
 * Write exactly `size` bytes from `buffer` to `filedes`, calling write()
 * again after every short write.
 *
 * @return 0 once everything has been written, -1 as soon as write() fails
 */
int writeloop(int filedes, const void *buffer, size_t size)
{
	const char *src = buffer;
	size_t done;

	for (done = 0; done < size; ) {
		size_t n = write(filedes, src+done, size-done);
		if (n == (size_t) -1)
			return -1;
		done += n;
	}
	return 0;
}
/**
 * Read exactly `size` bytes from `filedes` into `buffer`, calling read()
 * again after every short read.
 *
 * @return 0 once the buffer is full; -1 if read() fails or reports
 *         end-of-file before `size` bytes have arrived
 */
int readloop(int filedes, void *buffer, size_t size)
{
	char *dest = buffer;
	size_t remaining = size;

	while (remaining > 0) {
		size_t n = read(filedes, dest, remaining);
		if (n == 0 /* EOF */ || n == (size_t) -1 /* error */)
			return -1;
		dest += n;
		remaining -= n;
	}
	return 0;
}
/**
 * Copy `count` bytes from `in_fd` to `out_fd` with sendfile64(), repeating
 * the call after short transfers.
 *
 * @param offset  in/out: file position in `in_fd` to read from. sendfile64()
 *                advances it itself, so the same pointer is passed on every
 *                iteration.
 * @return 0 once `count` bytes have been sent; -1 on error or if `in_fd`
 *         runs out of data first
 */
int sendfileloop(int out_fd, int in_fd, off64_t *offset, size_t count)
{
	size_t sent = 0;

	while (sent < count) {
		ssize_t result = sendfile64(out_fd, in_fd, offset, count-sent);
		/* 0 means no more input; treating it as failure avoids spinning
		 * forever on a file shorter than requested. */
		if (result <= 0)
			return -1;
		sent += result;
	}
	return 0;
}
/**
 * Move `len` bytes from `fd_in` to `fd_out` with splice(), going through a
 * temporary pipe created for the duration of the call.
 *
 * @return 0 if all `len` bytes were moved; -1 if the pipe could not be
 *         created, a splice() failed, or the input ended early
 */
int splice_via_pipe_loop(int fd_in, int fd_out, size_t len)
{
	int pipefd[2];
	size_t spliced = 0;

	if (pipe(pipefd) == -1)
		return -1;

	while (spliced < len) {
		/* Signed results, so that a -1 error return is actually seen. */
		ssize_t filled, drained = 0;

		filled = splice(fd_in, NULL, pipefd[1], NULL, len-spliced, 0);
		if (filled <= 0)
			break;

		/* The pipe may drain in several pieces; keep going until
		 * everything just put into it has reached fd_out. */
		while (drained < filled) {
			ssize_t n = splice(pipefd[0], NULL, fd_out, NULL,
			                   filled-drained, 0);
			if (n <= 0)
				break;
			drained += n;
		}
		if (drained < filled)
			break;

		spliced += filled;
	}

	close(pipefd[0]);
	close(pipefd[1]);

	return spliced < len ? -1 : 0;
}

14
ioutil.h Normal file
View File

@@ -0,0 +1,14 @@
#ifndef __IOUTIL_H
#define __IOUTIL_H
#include "params.h"
/* Queries the FIEMAP extents of `fd` over the first `size` bytes and returns
 * an xmalloc'd bitset with the extent ranges (scaled down by `resolution`)
 * set; returns NULL if the ioctl fails. */
char* build_allocation_map(int fd, off64_t size, int resolution);
/* Calls write() until `size` bytes of `buffer` have gone to `filedes`.
 * Returns 0 on success, -1 if write() fails. */
int writeloop(int filedes, const void *buffer, size_t size);
/* Calls read() until `size` bytes have arrived in `buffer`.
 * Returns 0 on success, -1 on error or end-of-file. */
int readloop(int filedes, void *buffer, size_t size);
/* Sends `count` bytes from `in_fd` to `out_fd` using sendfile64(), reading
 * from the file position given via `offset`. Returns 0 on success, -1 on
 * failure. */
int sendfileloop(int out_fd, int in_fd, off64_t *offset, size_t count);
/* Moves `len` bytes from `fd_in` to `fd_out` using splice() through a
 * temporary pipe. Returns 0 if all bytes were moved, -1 otherwise. */
int splice_via_pipe_loop(int fd_in, int fd_out, size_t len);
#endif

36
nbdtypes.h Normal file
View File

@@ -0,0 +1,36 @@
#ifndef __NBDTYPES_H
#define __NBDTYPES_H
/* Protocol constants; see
 * http://linux.derkeiler.com/Mailing-Lists/Kernel/2003-09/2332.html */
/* Text placed in nbd_init.passwd by the server and checked by the client. */
#define INIT_PASSWD "NBDMAGIC"
/* Value of nbd_init.magic (sent big-endian). */
#define INIT_MAGIC 0x0000420281861253
/* Value every nbd_request.magic must carry. */
#define REQUEST_MAGIC 0x25609513
/* Value every nbd_reply.magic must carry. */
#define REPLY_MAGIC 0x67446698
/* Values of nbd_request.type. */
#define REQUEST_READ 0
#define REQUEST_WRITE 1
#define REQUEST_DISCONNECT 2
#include <linux/types.h>
/* Greeting sent by the server as soon as a client connects. The 64-bit
 * fields are big-endian on the wire. */
struct nbd_init {
/* INIT_PASSWD, without its terminating NUL */
char passwd[8];
/* INIT_MAGIC */
__be64 magic;
/* size in bytes of the file being served */
__be64 size;
/* filled with zeroes by the server */
char reserved[128];
};
/* Request header sent by the client. Integer fields are big-endian on the
 * wire; the struct is packed so it has no padding. */
struct nbd_request {
/* REQUEST_MAGIC */
__be32 magic;
__be32 type; /* REQUEST_READ, REQUEST_WRITE or REQUEST_DISCONNECT */
/* opaque bytes chosen by the client; the server copies them into its reply */
char handle[8];
/* byte offset the request starts at */
__be64 from;
/* number of bytes to read or write */
__be32 len;
} __attribute__((packed));
/* Reply header sent by the server for a read or write request. Integer
 * fields are big-endian on the wire. */
struct nbd_reply {
/* REPLY_MAGIC */
__be32 magic;
__be32 error; /* 0 = ok, else error */
char handle[8]; /* handle you got from request */
};
#endif

64
params.h Normal file
View File

@@ -0,0 +1,64 @@
#ifndef __PARAMS_H
#define __PARAMS_H
#define _GNU_SOURCE
#define _LARGEFILE64_SOURCE
#include <sys/types.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <unistd.h>
/* One socket address, viewable as whichever representation a call site
 * needs. All members share the same storage. */
union mysockaddr {
/* address family; compared against sa_family in the ACL check */
unsigned short family;
/* generic view, as handed to bind() and connect() */
struct sockaddr generic;
/* IPv4 view */
struct sockaddr_in v4;
/* IPv6 view */
struct sockaddr_in6 v6;
};
/* One access-control entry: an address plus a prefix length. */
struct ip_and_mask {
/* address to compare the client against */
union mysockaddr ip;
/* number of leading address bits that must match */
int mask;
};
/* Settings and state for "serve" mode. */
struct mode_serve_params {
/* address the listening socket is bound to */
union mysockaddr bind_to;
/* number of entries in `acl` */
int acl_entries;
/* access-control list; when NULL no ACL check is made on new clients */
struct ip_and_mask** acl;
/* path of the file being served; handed to every client thread */
char* filename;
/* backlog value passed to listen() */
int tcp_backlog;
/* file descriptor of the listening socket */
int server;
/* NOTE(review): not referenced by the code visible here - confirm use */
int threads;
/* lock and map are copied into every client_params on accept */
pthread_mutex_t block_allocation_map_lock;
char* block_allocation_map;
};
/* Settings and state for the client-side "read" and "write" modes. */
struct mode_readwrite_params {
/* server address to connect to */
union mysockaddr connect_to;
/* byte offset on the server where the transfer starts */
off64_t from;
/* number of bytes to transfer */
off64_t len;
/* local file descriptor the data is spliced to (read) or from (write) */
int data_fd;
/* connected socket, set by do_read()/do_write() */
int client;
};
/* Per-connection state, allocated by the accept loop and owned by the
 * client's thread. */
struct client_params {
/* connected client socket */
int socket;
/* path of the file being served (pointer shared with the server params) */
char* filename;
/* descriptor of the served file, opened by the client thread */
int fileno;
/* size of the served file, found by seeking to its end */
off64_t size;
/* start of the file's shared read/write memory mapping */
char* mapped;
/* NOTE(review): the accept loop assigns this mutex by value, so each client
 * holds a copy rather than the server's lock - confirm that is intended */
pthread_mutex_t block_allocation_map_lock;
/* same pointer as the server's block_allocation_map */
char* block_allocation_map;
};
/* Parameters for whichever mode was selected; only one member is in use
 * at a time since they share storage. */
union mode_params {
/* used by serve mode */
struct mode_serve_params serve;
/* used by read and write modes */
struct mode_readwrite_params readwrite;
};
#endif

124
readwrite.c Normal file
View File

@@ -0,0 +1,124 @@
#include "nbdtypes.h"
#include "ioutil.h"
#include "util.h"
#include "params.h"
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
/**
 * Open a TCP connection to `to` and return the connected socket.
 *
 * The socket family and the address length handed to connect() follow
 * to->sa_family, so IPv6 targets work as well as IPv4 ones. Any family other
 * than AF_INET6 is treated as IPv4. Failures are reported through
 * SERVER_ERROR_ON_FAILURE.
 *
 * `to` must point at storage large enough for its family (the callers here
 * pass the `generic` member of a union mysockaddr).
 */
int socket_connect(struct sockaddr* to)
{
	const int is_v6 = (to->sa_family == AF_INET6);
	const socklen_t to_len = is_v6
		? sizeof(struct sockaddr_in6)
		: sizeof(struct sockaddr_in);
	int fd = socket(is_v6 ? PF_INET6 : PF_INET, SOCK_STREAM, 0);

	SERVER_ERROR_ON_FAILURE(fd, "Couldn't create client socket");
	SERVER_ERROR_ON_FAILURE(connect(fd, to, to_len),
	  "connect failed");
	return fd;
}
/**
 * Read the server's greeting from `fd` and validate it.
 *
 * Reports an error (via the SERVER_ERROR macros) if the greeting cannot be
 * read or if its password or magic number is wrong.
 *
 * @return the size announced by the server, in host byte order
 */
off64_t socket_nbd_read_hello(int fd)
{
	struct nbd_init init;

	SERVER_ERROR_ON_FAILURE(readloop(fd, &init, sizeof(init)),
	  "Couldn't read init");
	if (strncmp(init.passwd, INIT_PASSWD, 8) != 0)
		SERVER_ERROR("wrong passwd");
	if (be64toh(init.magic) != INIT_MAGIC)
		/* The magic is 64 bits wide, so it needs a 64-bit conversion. */
		SERVER_ERROR("wrong magic (%llx)",
		  (unsigned long long) be64toh(init.magic));
	return be64toh(init.size);
}
/**
 * Fill in a request header ready to be sent to the server.
 *
 * @param request  header to fill; multi-byte fields are stored big-endian
 * @param type     REQUEST_READ, REQUEST_WRITE or REQUEST_DISCONNECT
 * @param from     byte offset of the request; 64 bits wide, like the wire
 *                 field, so large offsets are not truncated
 * @param len      number of bytes
 *
 * The handle is filled with two rand() values; read_reply() later checks
 * that the server echoed it.
 */
void fill_request(struct nbd_request *request, int type, off64_t from, int len)
{
	int word;

	request->magic = htobe32(REQUEST_MAGIC);
	request->type = htobe32(type);

	/* Copy the random words in with memcpy: `handle` is a char array inside
	 * a packed struct, so writing through an int pointer is not safe. */
	word = rand();
	memcpy(request->handle, &word, sizeof(word));
	word = rand();
	memcpy(request->handle + sizeof(word), &word, sizeof(word));

	request->from = htobe64(from);
	request->len = htobe32(len);
}
void read_reply(int fd, struct nbd_request *request, struct nbd_reply *reply)
{
SERVER_ERROR_ON_FAILURE(readloop(fd, reply, sizeof(*reply)),
"Couldn't read reply");
if (be32toh(reply->magic) != REPLY_MAGIC)
SERVER_ERROR("Reply magic incorrect (%p)", reply->magic);
if (be32toh(reply->error) != 0)
SERVER_ERROR("Server replied with error %d", reply->error);
if (strncmp(request->handle, reply->handle, 8) != 0)
SERVER_ERROR("Did not reply with correct handle");
}
void socket_nbd_read(int fd, off64_t from, int len, int out_fd, void* out_buf)
{
struct nbd_request request;
struct nbd_reply reply;
fill_request(&request, REQUEST_READ, from, len);
SERVER_ERROR_ON_FAILURE(writeloop(fd, &request, sizeof(request)),
"Couldn't write request");
read_reply(fd, &request, &reply);
if (out_buf) {
SERVER_ERROR_ON_FAILURE(readloop(fd, out_buf, len),
"Read failed");
}
else {
SERVER_ERROR_ON_FAILURE(
splice_via_pipe_loop(fd, out_fd, len),
"Splice failed"
);
}
}
void socket_nbd_write(int fd, off64_t from, int len, int in_fd, void* in_buf)
{
struct nbd_request request;
struct nbd_reply reply;
fill_request(&request, REQUEST_WRITE, from, len);
SERVER_ERROR_ON_FAILURE(writeloop(fd, &request, sizeof(request)),
"Couldn't write request");
if (in_buf) {
SERVER_ERROR_ON_FAILURE(writeloop(fd, in_buf, len),
"Write failed");
}
else {
SERVER_ERROR_ON_FAILURE(
splice_via_pipe_loop(in_fd, fd, len),
"Splice failed"
);
}
read_reply(fd, &request, &reply);
}
/*
 * Read the server's greeting from params->client and report an error unless
 * the requested range [params->from, params->from + params->len) lies inside
 * the size the server announced. `error_type` is a string literal prefixed
 * to the message. Expects a variable named `params` in the caller's scope.
 *
 * A request ending exactly at the end of the device is allowed. The
 * comparison is written so the sum from+len is never computed and cannot
 * overflow. Wrapped in do/while(0) so it acts as a single statement.
 */
#define CHECK_RANGE(error_type) do { \
	off64_t size = socket_nbd_read_hello(params->client); \
	if (params->from < 0 || params->len < 0 || \
	    params->from > size || params->len > size - params->from) \
		SERVER_ERROR(error_type \
		  " request %lld+%lld is out of range given size %lld", \
		  (long long) params->from, (long long) params->len, \
		  (long long) size \
		); \
} while (0)
/**
 * "read" mode: connect to the server, check the requested range against the
 * size it announces, splice the data to params->data_fd and disconnect.
 */
void do_read(struct mode_readwrite_params* params)
{
	const int sock = socket_connect(&params->connect_to.generic);

	/* CHECK_RANGE reads the greeting from params->client. */
	params->client = sock;
	CHECK_RANGE("read");

	socket_nbd_read(sock, params->from, params->len,
	  params->data_fd, NULL);
	close(sock);
}
/**
 * "write" mode: connect to the server, check the requested range against the
 * size it announces, splice the data from params->data_fd and disconnect.
 */
void do_write(struct mode_readwrite_params* params)
{
	const int sock = socket_connect(&params->connect_to.generic);

	/* CHECK_RANGE reads the greeting from params->client. */
	params->client = sock;
	CHECK_RANGE("write");

	socket_nbd_write(sock, params->from, params->len,
	  params->data_fd, NULL);
	close(sock);
}

283
serve.c Normal file
View File

@@ -0,0 +1,283 @@
#include "params.h"
#include "nbdtypes.h"
#include "ioutil.h"
#include "util.h"
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/mman.h>
#include <fcntl.h>
#include <string.h>
#include <stdlib.h>
#include <errno.h>
/**
 * Read one request from the client and serve it.
 *
 * Reads are answered with a reply header followed by the data, sent straight
 * from the file with sendfileloop(). Writes are read from the socket into
 * the file's memory mapping and then acknowledged. A request outside the
 * file gets an error reply.
 *
 * @return 0 if the caller should keep serving this client, 1 if the
 *         connection is finished (end-of-file or a disconnect request).
 *         Protocol and I/O errors are reported through the CLIENT_ERROR
 *         macros.
 *
 * NOTE(review): when a write request is rejected as out of range, the data
 * the client sends after the header is not consumed here - confirm how the
 * stream is meant to recover.
 */
int client_serve_request(struct client_params* client)
{
	off64_t offset;
	struct nbd_request request;
	struct nbd_reply reply;
	unsigned long long from;
	unsigned int len, type;

	/* readloop() returns -1 for end-of-file as well as for errors. Clear
	 * errno first so a stale value cannot make a clean EOF look like an
	 * error. */
	errno = 0;
	if (readloop(client->socket, &request, sizeof(request)) == -1) {
		if (errno == 0) {
			debug("EOF reading request");
			return 1; /* neat point to close the socket */
		}
		else {
			CLIENT_ERROR_ON_FAILURE(-1, "Error reading request");
		}
	}

	type = be32toh(request.type);
	from = be64toh(request.from);
	len = be32toh(request.len);

	reply.magic = htobe32(REPLY_MAGIC);
	reply.error = htobe32(0);
	memcpy(reply.handle, request.handle, 8);

	debug("request type %d", (int) type);

	if (be32toh(request.magic) != REQUEST_MAGIC)
		CLIENT_ERROR("Bad magic %08x",
		  (unsigned int) be32toh(request.magic));

	/* First pass: validate the request. */
	switch (type)
	{
	case REQUEST_READ:
	case REQUEST_WRITE:
		/* check it's not out of range; `from` is unsigned, and the
		 * comparison avoids computing from+len, which could wrap */
		if (client->size < 0 ||
		    from > (unsigned long long) client->size ||
		    len > (unsigned long long) client->size - from) {
			debug("request read %llu+%u out of range", from, len);
			reply.error = htobe32(1);
			CLIENT_ERROR_ON_FAILURE(
				writeloop(client->socket, &reply, sizeof(reply)),
				"Couldn't send error reply"
			);
			return 0;
		}
		break;

	case REQUEST_DISCONNECT:
		debug("request disconnect");
		return 1;

	default:
		CLIENT_ERROR("Unknown request %08x", type);
	}

	/* Second pass: carry it out. */
	switch (type)
	{
	case REQUEST_READ:
		debug("request read %llu+%u", from, len);
		CLIENT_ERROR_ON_FAILURE(
			writeloop(client->socket, &reply, sizeof(reply)),
			"Couldn't send read reply"
		);

		offset = from;
		CLIENT_ERROR_ON_FAILURE(
			sendfileloop(
				client->socket,
				client->fileno,
				&offset,
				len
			),
			"sendfile failed from=%lld, len=%u",
			(long long) offset,
			len
		);
		break;

	case REQUEST_WRITE:
		debug("request write %llu+%u", from, len);
		CLIENT_ERROR_ON_FAILURE(
			readloop(
				client->socket,
				client->mapped + from,
				len
			),
			"read failed from=%llu, len=%u",
			from,
			len
		);
		CLIENT_ERROR_ON_FAILURE(
			writeloop(client->socket, &reply, sizeof(reply)),
			"Couldn't send write reply"
		);
		break;
	}
	return 0;
}
/**
 * Open the served file for this client, find its size and map it into
 * memory.
 *
 * Fills in client->fileno, client->size and client->mapped. Any failure is
 * reported through CLIENT_ERROR_ON_FAILURE.
 */
void client_open_file(struct client_params* client)
{
	client->fileno = open(client->filename, O_RDWR|O_DIRECT|O_SYNC);
	CLIENT_ERROR_ON_FAILURE(client->fileno, "Couldn't open %s",
	  client->filename);

	/* The result of the seek is the size; that is what has to be checked. */
	client->size = lseek64(client->fileno, 0, SEEK_END);
	CLIENT_ERROR_ON_FAILURE(client->size, "Couldn't seek to end of %s",
	  client->filename);

	client->mapped = mmap64(NULL, client->size, PROT_READ|PROT_WRITE,
	  MAP_SHARED, client->fileno, 0);
	/* mmap signals failure with MAP_FAILED, not with a negative number. */
	CLIENT_ERROR_ON_FAILURE((client->mapped == MAP_FAILED ? -1 : 0),
	  "Couldn't map file %s", client->filename);

	debug("opened %s size %lld on fd %d @ %p", client->filename,
	  (long long) client->size, client->fileno, (void*) client->mapped);
}
/**
 * Send the greeting (password, magic number and file size) to a newly
 * connected client. A failed send is reported through
 * CLIENT_ERROR_ON_FAILURE.
 */
void client_send_hello(struct client_params* client)
{
	struct nbd_init init;

	/* Copy exactly the 8 bytes of the field: sizeof(INIT_PASSWD) would also
	 * count the string's terminating NUL and write past the array. */
	memcpy(init.passwd, INIT_PASSWD, sizeof(init.passwd));
	init.magic = htobe64(INIT_MAGIC);
	init.size = htobe64(client->size);
	memset(init.reserved, 0, sizeof(init.reserved));

	CLIENT_ERROR_ON_FAILURE(
		writeloop(client->socket, &init, sizeof(init)),
		"Couldn't send hello"
	);
}
/**
 * Thread entry point for one client connection.
 *
 * Opens and maps the served file, sends the greeting, then serves requests
 * until client_serve_request() says the connection is finished. Releases the
 * mapping, the file descriptor, the socket and the client_params structure
 * before returning.
 *
 * @param client_uncast  the struct client_params allocated by the accept
 *                       loop; ownership passes to this thread
 * @return always NULL
 */
void* client_serve(void* client_uncast)
{
	struct client_params* client = (struct client_params*) client_uncast;

	client_open_file(client);
	client_send_hello(client);

	while (client_serve_request(client) == 0)
		;

	/* Give back what client_open_file() acquired; otherwise every
	 * connection would leave a mapping and a descriptor behind. */
	munmap(client->mapped, client->size);
	close(client->fileno);

	CLIENT_ERROR_ON_FAILURE(
		close(client->socket),
		"Couldn't close socket %d",
		client->socket
	);

	free(client);
	return NULL;
}
/* testmasks[n] has the top n bits of a byte set (n = 0..8). */
static int testmasks[9] = { 0,128,192,224,240,248,252,254,255 };

/**
 * Check whether the address `test` matches any entry of an access-control
 * list.
 *
 * An entry matches when it has the same address family as `test` and the
 * first `mask` bits of the two addresses are equal. Entries of any family
 * other than AF_INET / AF_INET6 are skipped. A mask of 0 matches every
 * address of the same family.
 *
 * @param list_length  number of entries in `list`
 * @param list         the ACL entries
 * @param test         address to look up
 * @return 1 if some entry matches, 0 otherwise
 */
int is_included_in_acl(int list_length, struct ip_and_mask** list, struct sockaddr* test)
{
	int i;

	for (i=0; i < list_length; i++) {
		struct ip_and_mask *entry = list[i];
		int testbits, address_bits, matched = 1;
		unsigned char *raw_address1, *raw_address2;

		debug("checking acl entry %d", i);

		if (test->sa_family != entry->ip.family)
			continue;

		if (test->sa_family == AF_INET) {
			raw_address1 = (unsigned char*)
			  &((struct sockaddr_in*) test)->sin_addr;
			raw_address2 = (unsigned char*) &entry->ip.v4.sin_addr;
			address_bits = 32;
		}
		else if (test->sa_family == AF_INET6) {
			raw_address1 = (unsigned char*)
			  &((struct sockaddr_in6*) test)->sin6_addr;
			raw_address2 = (unsigned char*) &entry->ip.v6.sin6_addr;
			address_bits = 128;
		}
		else {
			/* No address bytes to compare for other families; going
			 * on would read through uninitialised pointers. */
			continue;
		}

		/* Never compare more bits than the address actually has. */
		testbits = entry->mask;
		if (testbits > address_bits)
			testbits = address_bits;

		/* Whole bytes first, then the leading bits of the last byte. */
		for (; testbits > 0; testbits -= 8) {
			debug("testbits=%d, c1=%d, c2=%d", testbits, raw_address1[0], raw_address2[0]);
			if (testbits >= 8) {
				if (raw_address1[0] != raw_address2[0]) {
					matched = 0;
					break;
				}
			}
			else {
				if ((raw_address1[0] & testmasks[testbits%8]) !=
				    (raw_address2[0] & testmasks[testbits%8]) ) {
					matched = 0;
					break;
				}
			}

			raw_address1++;
			raw_address2++;
		}

		if (matched)
			return 1;

		debug("no match");
	}

	return 0;
}
/**
 * Create the listening socket, bind it to params->bind_to and start
 * listening with a backlog of params->tcp_backlog. The descriptor is stored
 * in params->server.
 *
 * The socket family and the length given to bind() follow the family of
 * bind_to, so IPv6 addresses work as well as IPv4 ones. Any family other
 * than AF_INET6 is treated as IPv4. Failures are reported through
 * SERVER_ERROR_ON_FAILURE.
 */
void serve_open_socket(struct mode_serve_params* params)
{
	const int is_v6 = (params->bind_to.family == AF_INET6);
	const socklen_t bind_len = is_v6
		? sizeof(params->bind_to.v6)
		: sizeof(params->bind_to.v4);

	params->server = socket(is_v6 ? PF_INET6 : PF_INET, SOCK_STREAM, 0);

	SERVER_ERROR_ON_FAILURE(params->server,
	  "Couldn't create server socket");

	SERVER_ERROR_ON_FAILURE(
		bind(params->server, &params->bind_to.generic, bind_len),
		"Couldn't bind server to IP address"
	);

	SERVER_ERROR_ON_FAILURE(
		listen(params->server, params->tcp_backlog),
		"Couldn't listen on server socket"
	);
}
void serve_accept_loop(struct mode_serve_params* params)
{
while (1) {
pthread_t client_thread;
struct sockaddr client_address;
struct client_params* client_params;
socklen_t socket_length=0;
int client_socket = accept(params->server, &client_address,
&socket_length);
SERVER_ERROR_ON_FAILURE(client_socket, "accept() failed");
if (params->acl &&
!is_included_in_acl(params->acl_entries, params->acl, &client_address)) {
write(client_socket, "Access control error", 20);
close(client_socket);
continue;
}
client_params = xmalloc(sizeof(struct client_params));
client_params->socket = client_socket;
client_params->filename = params->filename;
client_params->block_allocation_map =
params->block_allocation_map;
client_params->block_allocation_map_lock =
params->block_allocation_map_lock;
client_thread = pthread_create(&client_thread, NULL,
client_serve, client_params);
SERVER_ERROR_ON_FAILURE(client_thread,
"Failed to create client thread");
/* FIXME: keep track of them? */
/* FIXME: maybe shouldn't be fatal? */
}
}
/* "serve" mode entry point: opens the listening socket described by
 * `params`, then runs the accept loop on it. */
void do_serve(struct mode_serve_params* params)
{
/* sets params->server */
serve_open_socket(params);
/* loops on accept(); contains no exit from its while (1) */
serve_accept_loop(params);
}

58
util.c Normal file
View File

@@ -0,0 +1,58 @@
#include <stdarg.h>
#include <stdio.h>
#include <pthread.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <malloc.h>
#include <unistd.h>
#include "util.h"
/* The thread that called error_init(); error() compares against it to
 * decide between exiting the process and exiting only the current thread. */
static pthread_t main_thread;

/**
 * Remember the calling thread as the main thread.
 */
void error_init()
{
	pthread_t caller = pthread_self();

	main_thread = caller;
}
/**
 * Print an error message to stderr and terminate the caller.
 *
 * The message is "*** " followed by the printf-style `format` and its
 * arguments, then the errno number and text if requested, then a newline.
 *
 * @param consult_errno  non-zero to append the errno value and description
 * @param close_socket   descriptor to close before terminating; 0 means
 *                       there is none
 * @param format         printf-style format string
 *
 * Does not return: on the thread recorded by error_init() it calls exit(1),
 * on any other thread it calls pthread_exit() with the value 1.
 */
void error(int consult_errno, int close_socket, const char* format, ...)
{
	/* Capture errno before any output: the stdio calls below may change
	 * it, and the value at the time of the failure is the one to report. */
	const int saved_errno = errno;
	va_list argptr;

	fprintf(stderr, "*** ");

	va_start(argptr, format);
	vfprintf(stderr, format, argptr);
	va_end(argptr);

	if (consult_errno) {
		fprintf(stderr, " (errno=%d, %s)", saved_errno,
		  strerror(saved_errno));
	}

	if (close_socket)
		close(close_socket);

	fprintf(stderr, "\n");

	if (pthread_equal(pthread_self(), main_thread))
		exit(1);
	else
		pthread_exit((void*) 1);
}
/**
 * realloc() wrapper that treats allocation failure as an error.
 *
 * @param ptr   block to resize, or NULL to allocate a new one
 * @param size  requested size in bytes
 * @return the (possibly moved) block; if realloc() returns NULL the failure
 *         is reported through SERVER_ERROR
 */
void* xrealloc(void* ptr, size_t size)
{
	void* p = realloc(ptr, size);
	if (p == NULL)
		/* %zu is the conversion for size_t; %d would misread it. */
		SERVER_ERROR("couldn't xrealloc %zu bytes", size);
	return p;
}
/**
 * Allocate `size` bytes through xrealloc() and fill them with zeroes.
 *
 * @return the zero-filled block
 */
void* xmalloc(size_t size)
{
	void* block = xrealloc(NULL, size);

	/* memset() hands back the pointer it was given. */
	return memset(block, 0, size);
}

33
util.h Normal file
View File

@@ -0,0 +1,33 @@
#ifndef __UTIL_H
#define __UTIL_H
#include <stdio.h>
#include <pthread.h>
/* Records the calling thread as the main thread, for use by error(). */
void error_init();
/* Prints a printf-style message to stderr (plus errno details when
 * consult_errno is non-zero), closes close_socket when it is non-zero, then
 * calls exit(1) on the main thread or pthread_exit() on any other thread. */
void error(int consult_errno, int close_socket, const char* format, ...);
/* realloc() wrapper that reports an error when the allocation fails. */
void* xrealloc(void* ptr, size_t size);
/* Allocates size bytes via xrealloc() and zero-fills them. */
void* xmalloc(size_t size);
/* debug(msg, ...): printf-style trace output.
 * Without DEBUG defined the macro expands to nothing. With DEBUG defined it
 * prints one line to stderr: the pthread_self() value and the clock() value
 * (both cast to int), then the formatted message. `msg` must be a string
 * literal because it is pasted between two other literals.
 * NOTE(review): clock() is normally declared by <time.h>; confirm it is in
 * scope via the headers included here. */
#ifndef DEBUG
# define debug(msg, ...)
#else
# include <sys/times.h>
# define debug(msg, ...) fprintf(stderr, "%08x %4d: " msg "\n" , \
(int) pthread_self(), (int) clock(), ##__VA_ARGS__)
#endif
/*
 * Shorthands for error().
 *
 * CLIENT_ERROR*  expect a variable `client` in scope and pass client->socket
 *                to error(), which closes it.
 * SERVER_ERROR*  pass 0 as the socket, so nothing is closed.
 *
 * The plain forms always report and do not consult errno. The _ON_FAILURE
 * forms report, with errno details, only when `test` evaluates to a negative
 * value.
 *
 * The _ON_FAILURE forms parenthesise `test` and are wrapped in do/while(0),
 * so they take any expression and behave as one statement (safe in an
 * unbraced if/else). Use them with a trailing semicolon.
 */
#define CLIENT_ERROR(msg, ...) \
	error(0, client->socket, msg, ##__VA_ARGS__)
#define CLIENT_ERROR_ON_FAILURE(test, msg, ...) \
	do { \
		if ((test) < 0) { error(1, client->socket, msg, ##__VA_ARGS__); } \
	} while (0)
#define SERVER_ERROR(msg, ...) \
	error(0, 0, msg, ##__VA_ARGS__)
#define SERVER_ERROR_ON_FAILURE(test, msg, ...) \
	do { \
		if ((test) < 0) { error(1, 0, msg, ##__VA_ARGS__); } \
	} while (0)
#endif