diff options
Diffstat (limited to 'net/lwip/wget.c')
-rw-r--r-- | net/lwip/wget.c | 122 |
1 files changed, 90 insertions, 32 deletions
diff --git a/net/lwip/wget.c b/net/lwip/wget.c index a3b82908877..ea1113e18b1 100644 --- a/net/lwip/wget.c +++ b/net/lwip/wget.c @@ -6,8 +6,10 @@ #include <display_options.h> #include <efi_loader.h> #include <image.h> +#include <linux/kconfig.h> #include <lwip/apps/http_client.h> #include "lwip/altcp_tls.h" +#include <lwip/errno.h> #include <lwip/timeouts.h> #include <rng.h> #include <mapmem.h> @@ -201,11 +203,58 @@ static int parse_legacy_arg(char *arg, char *nurl, size_t rem) return 0; } +/** + * store_block() - copy received data + * + * This function is called by the receive callback to copy a block of data + * into its final location (ctx->daddr). Before doing so, it checks if the copy + * is allowed. + * + * @ctx: the context for the current transfer + * @src: the data received from the TCP stack + * @len: the length of the data + */ +static int store_block(struct wget_ctx *ctx, void *src, u16_t len) +{ + ulong store_addr = ctx->daddr; + uchar *ptr; + + /* Avoid overflow */ + if (wget_info->buffer_size && wget_info->buffer_size < ctx->size + len) + return -1; + + if (CONFIG_IS_ENABLED(LMB) && wget_info->set_bootdev) { + if (store_addr + len < store_addr || + lmb_read_check(store_addr, len)) { + if (!wget_info->silent) { + printf("\nwget error: "); + printf("trying to overwrite reserved memory\n"); + } + return -1; + } + } + + ptr = map_sysmem(store_addr, len); + memcpy(ptr, src, len); + unmap_sysmem(ptr); + + ctx->daddr += len; + ctx->size += len; + if (ctx->size - ctx->prevsize > PROGRESS_PRINT_STEP_BYTES) { + if (!wget_info->silent) + printf("#"); + ctx->prevsize = ctx->size; + } + + return 0; +} + static err_t httpc_recv_cb(void *arg, struct altcp_pcb *pcb, struct pbuf *pbuf, err_t err) { struct wget_ctx *ctx = arg; struct pbuf *buf; + err_t ret; if (!pbuf) return ERR_BUF; @@ -214,18 +263,17 @@ static err_t httpc_recv_cb(void *arg, struct altcp_pcb *pcb, struct pbuf *pbuf, ctx->start_time = get_timer(0); for (buf = pbuf; buf; buf = buf->next) { - memcpy((void *)ctx->daddr, buf->payload, buf->len); - ctx->daddr += buf->len; - ctx->size += buf->len; - if (ctx->size - ctx->prevsize > PROGRESS_PRINT_STEP_BYTES) { - printf("#"); - ctx->prevsize = ctx->size; + if (store_block(ctx, buf->payload, buf->len) < 0) { + altcp_abort(pcb); + ret = ERR_BUF; + goto out; } } - altcp_recved(pcb, pbuf->tot_len); + ret = ERR_OK; +out: pbuf_free(pbuf); - return ERR_OK; + return ret; } static void httpc_result_cb(void *arg, httpc_result_t httpc_result, @@ -255,11 +303,15 @@ static void httpc_result_cb(void *arg, httpc_result_t httpc_result, elapsed = get_timer(ctx->start_time); if (!elapsed) elapsed = 1; - if (rx_content_len > PROGRESS_PRINT_STEP_BYTES) - printf("\n"); - printf("%u bytes transferred in %lu ms (", rx_content_len, elapsed); - print_size(rx_content_len / elapsed * 1000, "/s)\n"); - printf("Bytes transferred = %lu (%lx hex)\n", ctx->size, ctx->size); + if (!wget_info->silent) { + if (rx_content_len > PROGRESS_PRINT_STEP_BYTES) + printf("\n"); + printf("%u bytes transferred in %lu ms (", rx_content_len, + elapsed); + print_size(rx_content_len / elapsed * 1000, "/s)\n"); + printf("Bytes transferred = %lu (%lx hex)\n", ctx->size, + ctx->size); + } if (wget_info->set_bootdev) efi_set_bootdev("Http", ctx->server_name, ctx->path, map_sysmem(ctx->saved_daddr, 0), rx_content_len); @@ -339,7 +391,8 @@ static int _set_cacert(const void *addr, size_t sz) mbedtls_x509_crt_init(&crt); ret = mbedtls_x509_crt_parse(&crt, cacert, cacert_size); if (ret) { - printf("Could not parse certificates (%d)\n", ret); + if (!wget_info->silent) + printf("Could not parse certificates (%d)\n", ret); free(cacert); cacert = NULL; cacert_size = 0; @@ -372,13 +425,14 @@ static int set_cacert(char * const saddr, char * const ssz) #endif #endif /* CONFIG_WGET_CACERT || CONFIG_WGET_BUILTIN_CACERT */ -static int wget_loop(struct udevice *udev, ulong dst_addr, char *uri) +int wget_do_request(ulong dst_addr, char *uri) { #if CONFIG_IS_ENABLED(WGET_HTTPS) altcp_allocator_t tls_allocator; #endif httpc_connection_t conn; httpc_state_t *state; + struct udevice *udev; struct netif *netif; struct wget_ctx ctx; char *path; @@ -394,6 +448,14 @@ static int wget_loop(struct udevice *udev, ulong dst_addr, char *uri) if (parse_url(uri, ctx.server_name, &ctx.port, &path, &is_https)) return CMD_RET_USAGE; + if (net_lwip_eth_start() < 0) + return CMD_RET_FAILURE; + + if (!wget_info) + wget_info = &default_wget_info; + + udev = eth_get_dev(); + netif = net_lwip_new_netif(udev); if (!netif) return -1; @@ -413,9 +475,10 @@ static int wget_loop(struct udevice *udev, ulong dst_addr, char *uri) if (cacert_auth_mode == AUTH_REQUIRED) { if (!ca || !ca_sz) { - printf("Error: cacert authentication mode is " - "'required' but no CA certificates " - "given\n"); + if (!wget_info->silent) + printf("Error: cacert authentication " + "mode is 'required' but no CA " + "certificates given\n"); return CMD_RET_FAILURE; } } else if (cacert_auth_mode == AUTH_NONE) { @@ -430,6 +493,10 @@ static int wget_loop(struct udevice *udev, ulong dst_addr, char *uri) */ } + if (!ca && !wget_info->silent) { + printf("WARNING: no CA certificates, "); + printf("HTTPS connections not authenticated\n"); + } tls_allocator.alloc = &altcp_tls_alloc; tls_allocator.arg = altcp_tls_create_config_client(ca, ca_sz, @@ -454,6 +521,8 @@ static int wget_loop(struct udevice *udev, ulong dst_addr, char *uri) return CMD_RET_FAILURE; } + errno = 0; + while (!ctx.done) { net_lwip_rx(udev, netif); sys_check_timeouts(); @@ -466,21 +535,10 @@ static int wget_loop(struct udevice *udev, ulong dst_addr, char *uri) if (ctx.done == SUCCESS) return 0; - return -1; -} - -int wget_do_request(ulong dst_addr, char *uri) -{ - int ret; + if (errno == EPERM && !wget_info->silent) + printf("Certificate verification failed\n"); - ret = net_lwip_eth_start(); - if (ret < 0) - return ret; - - if (!wget_info) - wget_info = &default_wget_info; - - return wget_loop(eth_get_dev(), dst_addr, uri); + return -1; } int do_wget(struct cmd_tbl *cmdtp, int flag, int argc, char * const argv[]) |