ia64/xen-unstable

view tools/vnet/vnet-module/esp.c @ 8740:3d7ea7972b39

Update patches for linux 2.6.15.

Signed-off-by: Christian Limpach <Christian.Limpach@cl.cam.ac.uk>
author cl349@firebug.cl.cam.ac.uk
date Thu Feb 02 17:16:00 2006 +0000 (2006-02-02)
parents 0a4b76b6b5a0
children 71b0f00f6344
line source
1 /*
2 * Copyright (C) 2004 Mike Wray <mike.wray@hp.com>
3 *
4 * This program is free software; you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by the
6 * Free Software Foundation; either version 2 of the License, or (at your
7 * option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful, but
10 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
11 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * for more details.
13 *
14 * You should have received a copy of the GNU General Public License along
15 * with this program; if not, write to the Free software Foundation, Inc.,
16 * 59 Temple Place, suite 330, Boston, MA 02111-1307 USA
17 *
18 */
19 #include <linux/config.h>
20 #include <linux/module.h>
21 #include <linux/types.h>
22 #include <linux/sched.h>
23 #include <linux/kernel.h>
24 #include <asm/uaccess.h>
26 #include <linux/init.h>
28 #include <linux/version.h>
30 #include <linux/skbuff.h>
31 #include <linux/netdevice.h>
32 #include <linux/net.h>
33 #include <linux/in.h>
34 #include <linux/inet.h>
36 #include <net/ip.h>
37 #include <net/protocol.h>
38 #include <net/route.h>
40 #include <linux/if_ether.h>
41 #include <linux/icmp.h>
43 #include <asm/scatterlist.h>
44 #include <linux/crypto.h>
45 #include <linux/pfkeyv2.h>
46 #include <linux/random.h>
48 #include <esp.h>
49 #include <sa.h>
50 #include <sa_algorithm.h>
51 #include <tunnel.h>
52 #include <vnet.h>
53 #include <skb_util.h>
/** Set non-zero to dump packet bytes, keys and icv values (via
 * skb_print_bits()/buf_print()) while computing and checking icvs. */
static const int DEBUG_ICV = 0;
// Module name, presumably used as a prefix by the debug.h print macros — TODO confirm.
#define MODULE_NAME "IPSEC"
// Debug toggle: DEBUG is defined and then immediately undefined, so it is
// always undefined when debug.h is included (even if set on the command
// line). Presumably removing the #undef enables dprintf() output — confirm
// against debug.h.
#define DEBUG 1
#undef DEBUG
#include "debug.h"
62 /* Outgoing packet: [ eth | ip | data ]
63 * After etherip: [ eth2 | ip2 | ethip | eth | ip | data ]
64 * After esp : [ eth2 | ip2 | esp | {ethip | eth | ip | data} | pad | icv ]
65 * ^ +
66 * The curly braces { ... } denote encryption.
67 * The esp header includes the fixed esp headers and the iv (variable size).
68 * The point marked ^ does not move. To the left is in the header, to the right
69 * is in the frag. Remember that all outgoing skbs (from domains) have 1 frag.
70 * Data after + is added by esp, using an extra frag.
71 *
72 * Incoming as above.
73 * After decrypt: [ eth2 | ip2 | esp | ethip | eth | ip | data | pad | icv ]
74 * Trim tail: [ eth2 | ip2 | esp | ethip | eth | ip | data ]
75 * Drop hdr: [ eth2 | ip2 | ethip | eth | ip | data ]
76 * ^
77 * The point marked ^ does not move. Incoming skbs are linear (no frags).
78 * The tail is trimmed by adjusting skb->tail and len.
79 * The esp hdr is dropped by using memmove to move the headers and
80 * adjusting the skb pointers.
81 *
82 * todo: Now this code is in linux we can't assume 1 frag for outbound skbs,
83 * or (maybe) that memmove is safe on inbound.
84 */
/** Round n up to the nearest multiple of block that is >= n.
 * If block is less than 2 does nothing.
 * Works for any block size, not only powers of 2 (callers round to
 * esp->cipher.pad_n, which need not be a power of 2). For power-of-2
 * block sizes the result is the same as the classic (n + b-1) & ~(b-1).
 *
 * @param n to round up
 * @param block size to round to a multiple of
 * @return rounded value
 */
static inline int roundup(int n, int block){
    int rem;

    if(block <= 1) return n;
    rem = n % block;
    if(rem == 0) return n;
    // C division truncates toward zero: a positive remainder needs
    // (block - rem) added, a negative one needs -rem added, to reach the
    // next multiple at or above n.
    return (rem > 0) ? n + (block - rem) : n - rem;
}
/** Check if n is a multiple of block.
 * If block is less than 2 returns 1.
 * Works for any block size; for powers of 2 the result is the same as
 * testing the low bits of n.
 *
 * @param n to check
 * @param block block size
 * @return 1 if a multiple, 0 otherwise
 */
static inline int multipleof(int n, int block){
    if(block <= 1) return 1;
    return (n % block) == 0;
}
/** Number of whole bytes in n bits.
 * Uses integer division, so any partial byte is dropped (12 bits -> 1);
 * callers that need a full byte count round n up to a multiple of 8 first.
 *
 * @param n number of bits
 * @return number of bytes
 */
static inline int bits_to_bytes(int n){
    const int bits_per_byte = 8;
    int bytes = n / bits_per_byte;
    return bytes;
}
124 /** Insert esp padding at the end of an skb.
125 * Inserts padding bytes, number of padding bytes, protocol number.
126 *
127 * @param skb skb
128 * @param offset offset from skb end to where padding should end
129 * @param extra_n total amount of padding
130 * @param protocol protocol number (from original ip hdr)
131 * @return 0 on success, error code otherwise
132 */
133 static int esp_sa_pad(struct sk_buff *skb, int offset, int extra_n,
134 unsigned char protocol){
135 int err;
136 char *data;
137 int pad_n = extra_n - ESP_PAD_N;
138 int i;
139 char buf[extra_n];
141 data = buf;
142 for(i = 1; i <= pad_n; i++){
143 *data++ = i;
144 }
145 *data++ = pad_n;
146 *data++ = protocol;
147 err = skb_put_bits(skb, skb->len - offset - extra_n, buf, extra_n);
148 return err;
149 }
151 /** Encrypt skb. Skips esp header and iv.
152 * Assumes skb->data points at esp header.
153 *
154 * @param esp esp state
155 * @parm esph esp header
156 * @param skb packet
157 * @param head_n size of esp header and iv
158 * @param iv_n size of iv
159 * @param text_n size of ciphertext
160 * @return 0 on success, error code otherwise
161 */
162 static int esp_sa_encrypt(ESPState *esp, ESPHdr *esph, struct sk_buff *skb,
163 int head_n, int iv_n, int text_n){
164 int err = 0;
165 int sg_n = skb_shinfo(skb)->nr_frags + 1;
166 struct scatterlist sg[sg_n];
168 err = skb_scatterlist(skb, sg, &sg_n, head_n, text_n);
169 if(err) goto exit;
170 if(iv_n){
171 crypto_cipher_set_iv(esp->cipher.tfm, esp->cipher.iv, iv_n);
172 }
173 crypto_cipher_encrypt(esp->cipher.tfm, sg, sg, text_n);
174 if(iv_n){
175 memcpy(esph->data, esp->cipher.iv, iv_n);
176 crypto_cipher_get_iv(esp->cipher.tfm, esp->cipher.iv, iv_n);
177 }
178 exit:
179 return err;
180 }
182 /** Decrypt skb. Skips esp header and iv.
183 * Assumes skb->data points at esp header.
184 *
185 * @param esp esp state
186 * @parm esph esp header
187 * @param skb packet
188 * @param head_n size of esp header and iv
189 * @param iv_n size of iv
190 * @param text_n size of ciphertext
191 * @return 0 on success, error code otherwise
192 */
193 static int esp_sa_decrypt(ESPState *esp, ESPHdr *esph, struct sk_buff *skb,
194 int head_n, int iv_n, int text_n){
195 int err = 0;
196 int sg_n = skb_shinfo(skb)->nr_frags + 1;
197 struct scatterlist sg[sg_n];
199 err = skb_scatterlist(skb, sg, &sg_n, head_n, text_n);
200 if(err) goto exit;
201 if(iv_n){
202 crypto_cipher_set_iv(esp->cipher.tfm, esph->data, iv_n);
203 }
204 crypto_cipher_decrypt(esp->cipher.tfm, sg, sg, text_n);
205 exit:
206 return err;
207 }
209 /** Compute icv. Includes esp header, iv and ciphertext.
210 * Assumes skb->data points at esp header.
211 *
212 * @param esp esp state
213 * @param skb packet
214 * @param digest_n number of bytes to digest
215 * @param icv_n size of icv
216 * @return 0 on success, error code otherwise
217 */
218 static int esp_sa_digest(ESPState *esp, struct sk_buff *skb, int digest_n, int icv_n){
219 int err = 0;
220 u8 icv[icv_n];
222 if(DEBUG_ICV){
223 dprintf("> skb digest_n=%d icv_n=%d\n", digest_n, icv_n);
224 skb_print_bits(skb, 0, digest_n);
225 }
226 memset(icv, 0, icv_n);
227 esp->digest.icv(esp, skb, 0, digest_n, icv);
228 skb_put_bits(skb, digest_n, icv, icv_n);
229 return err;
230 }
232 /** Check the icv and trim it from the skb tail.
233 *
234 * @param sa sa state
235 * @param esp esp state
236 * @param esph esp header
237 * @param skb packet
238 * @return 0 on success, error code otherwise
239 */
240 static int esp_check_icv(SAState *sa, ESPState *esp, ESPHdr *esph, struct sk_buff *skb){
241 int err = 0;
242 int icv_n = esp->digest.icv_n;
243 int digest_n = skb->len - icv_n;
244 u8 icv_skb[icv_n];
245 u8 icv_new[icv_n];
247 dprintf(">\n");
248 if(DEBUG_ICV){
249 dprintf("> skb len=%d digest_n=%d icv_n=%d\n",
250 skb->len, digest_n, icv_n);
251 skb_print_bits(skb, 0, skb->len);
252 }
253 if(skb_copy_bits(skb, digest_n, icv_skb, icv_n)){
254 wprintf("> Error getting icv from skb\n");
255 goto exit;
256 }
257 esp->digest.icv(esp, skb, 0, digest_n, icv_new);
258 if(DEBUG_ICV){
259 dprintf("> len=%d icv_n=%d", digest_n, icv_n);
260 printk("\nskb="); buf_print(icv_skb, icv_n);
261 printk("new="); buf_print(icv_new, icv_n);
262 }
263 if(unlikely(memcmp(icv_new, icv_skb, icv_n))){
264 wprintf("> ICV check failed!\n");
265 err = -EINVAL;
266 sa->counts.integrity_failures++;
267 goto exit;
268 }
269 skb_trim_tail(skb, icv_n);
270 exit:
271 dprintf("< err=%d\n", err);
272 return err;
273 }
275 /** Send a packet via an ESP SA.
276 *
277 * @param sa SA state
278 * @param skb packet to send
279 * @param tunnel underlying tunnel
280 * @return 0 on success, negative error code otherwise
281 */
282 static int esp_sa_send(SAState *sa, struct sk_buff *skb, Tunnel *tunnel){
283 int err = 0;
284 int ip_n; // Size of ip header.
285 int plaintext_n; // Size of plaintext.
286 int ciphertext_n; // Size of ciphertext (including padding).
287 int extra_n; // Extra bytes needed for ciphertext.
288 int icv_n = 0; // Size of integrity check value (icv).
289 int iv_n = 0; // Size of initialization vector (iv).
290 int head_n; // Size of esp header and iv.
291 int tail_n; // Size of esp trailer: padding and icv.
292 ESPState *esp;
293 ESPHdr *esph;
295 dprintf(">\n");
296 esp = sa->data;
297 ip_n = (skb->nh.iph->ihl << 2);
298 // Assuming skb->data points at ethernet header, exclude ethernet
299 // header and IP header.
300 plaintext_n = skb->len - ETH_HLEN - ip_n;
301 // Add size of padding fields.
302 ciphertext_n = roundup(plaintext_n + ESP_PAD_N, esp->cipher.block_n);
303 if(esp->cipher.pad_n > 0){
304 ciphertext_n = roundup(ciphertext_n, esp->cipher.pad_n);
305 }
306 extra_n = ciphertext_n - plaintext_n;
307 iv_n = esp->cipher.iv_n;
308 icv_n = esp->digest.icv_n;
309 dprintf("> len=%d plaintext=%d ciphertext=%d extra=%d\n",
310 skb->len, plaintext_n, ciphertext_n, extra_n);
311 dprintf("> iv=%d icv=%d\n", iv_n, icv_n);
312 skb_print_bits(skb, 0, skb->len);
314 // Add headroom for esp header and iv, tailroom for the ciphertext
315 // and icv.
316 head_n = ESP_HDR_N + iv_n;
317 tail_n = extra_n + icv_n;
318 err = skb_make_room(&skb, skb, head_n, tail_n);
319 if(err) goto exit;
320 dprintf("> skb=%p\n", skb);
321 // Move the headers up to make space for the esp header. We can
322 // use memmove() since all this data fits in the skb head.
323 // todo: Can't assume this anymore?
324 dprintf("> header push...\n");
325 __skb_push(skb, head_n);
326 if(0 && skb->mac.raw){
327 dprintf("> skb->mac=%p\n", skb->mac.raw);
328 dprintf("> ETH header pull...\n");
329 memmove(skb->data, skb->mac.raw, ETH_HLEN);
330 skb->mac.raw = skb->data;
331 __skb_pull(skb, ETH_HLEN);
332 }
333 dprintf("> IP header pull...\n");
334 memmove(skb->data, skb->nh.raw, ip_n);
335 skb->nh.raw = skb->data;
336 __skb_pull(skb, ip_n);
337 esph = (void*)skb->data;
338 // Add spi and sequence number.
339 esph->spi = sa->ident.spi;
340 esph->seq = htonl(++sa->replay.send_seq);
341 // Insert the padding bytes: extra bytes less the pad fields
342 // themselves.
343 dprintf("> esp_sa_pad ...\n");
344 esp_sa_pad(skb, icv_n, extra_n, skb->nh.iph->protocol);
345 if(sa->security & SA_CONF){
346 dprintf("> esp_sa_encrypt...\n");
347 err = esp_sa_encrypt(esp, esph, skb, head_n, iv_n, ciphertext_n);
348 if(err) goto exit;
349 }
350 if(icv_n){
351 dprintf("> esp_sa_digest...\n");
352 err = esp_sa_digest(esp, skb, head_n + ciphertext_n, icv_n);
353 if(err) goto exit;
354 }
355 dprintf("> IP header push...\n");
356 __skb_push(skb, ip_n);
357 if(0 && skb->mac.raw){
358 dprintf("> ETH header push...\n");
359 __skb_push(skb, ETH_HLEN);
360 }
361 // Fix ip header. Adjust length field, set protocol, zero
362 // checksum.
363 {
364 // Total packet length (bytes).
365 int tot_len = ntohs(skb->nh.iph->tot_len);
366 tot_len += head_n;
367 tot_len += tail_n;
368 skb->nh.iph->protocol = IPPROTO_ESP;
369 skb->nh.iph->tot_len = htons(tot_len);
370 skb->nh.iph->check = 0;
371 }
372 err = Tunnel_send(tunnel, skb);
373 exit:
374 dprintf("< err=%d\n", err);
375 return err;
376 }
378 /** Release an skb context.
379 * Drops the refcount on the SA.
380 *
381 * @param context to free
382 */
383 static void esp_context_free_fn(SkbContext *context){
384 SAState *sa;
385 if(!context) return;
386 sa = context->data;
387 if(!sa) return;
388 context->data = NULL;
389 SAState_decref(sa);
390 }
392 /** Receive a packet via an ESP SA.
393 * Does ESP receive processing (check icv, decrypt), strips
394 * ESP header and re-receives.
395 *
396 * @param sa SA
397 * @param skb packet
398 * @return 0 on success, negative error code otherwise
399 */
400 static int esp_sa_recv(SAState *sa, struct sk_buff *skb){
401 int err = -EINVAL;
402 int mine = 0;
403 int vnet = 0; //todo: fixme - need to record skb vnet somewhere
404 ESPState *esp;
405 ESPHdr *esph;
406 ESPPadding *pad;
407 int block_n; // Cipher blocksize.
408 int icv_n; // Size of integrity check value (icv).
409 int iv_n; // Size of initialization vector (iv).
410 int text_n; // Size of text (ciphertext or plaintext).
411 int head_n; // Size of esp header and iv.
413 dprintf("> skb=%p\n", skb);
414 // Assumes skb->data points at esp hdr.
415 esph = (void*)skb->data;
416 esp = sa->data;
417 block_n = crypto_tfm_alg_blocksize(esp->cipher.tfm);
418 icv_n = esp->digest.icv_n;
419 iv_n = esp->cipher.iv_n;
420 head_n = ESP_HDR_N + iv_n;
421 text_n = skb->len - head_n - icv_n;
422 if(text_n < ESP_PAD_N || !multipleof(text_n, block_n)){
423 wprintf("> Invalid size: text_n=%d tfm:block_n=%d esp:block_n=%d\n",
424 text_n, block_n, esp->cipher.block_n);
425 goto exit;
426 }
427 if(icv_n){
428 err = esp_check_icv(sa, esp, esph, skb);
429 if(err) goto exit;
430 }
431 mine = 1;
432 if(sa->security & SA_CONF){
433 err = esp_sa_decrypt(esp, esph, skb, head_n, iv_n, text_n);
434 if(err) goto exit;
435 }
436 // Strip esp header by moving the other headers down.
437 //todo Maybe not safe to do this anymore.
438 memmove(skb->mac.raw + head_n, skb->mac.raw, (skb->data - skb->mac.raw));
439 skb->mac.raw += head_n;
440 skb->nh.raw += head_n;
441 // Move skb->data back to ethernet header.
442 // Do in 2 moves to ensure offsets are +ve,
443 // since args to skb_pull/skb_push are unsigned.
444 __skb_pull(skb, head_n);
445 __skb_push(skb, skb->data - skb->mac.raw);
446 // After this esph is invalid.
447 esph = NULL;
448 // Trim padding, restore protocol in IP header.
449 pad = skb_trim_tail(skb, ESP_PAD_N);
450 text_n -= ESP_PAD_N;
451 if((pad->pad_n > 255) | (pad->pad_n > text_n)){
452 wprintf("> Invalid padding: pad_n=%d text_n=%d\n", pad->pad_n, text_n);
453 goto exit;
454 }
455 skb_trim_tail(skb, pad->pad_n);
456 skb->nh.iph->protocol = pad->protocol;
457 err = skb_push_context(skb, vnet, sa->ident.addr, IPPROTO_ESP,
458 sa, esp_context_free_fn);
459 if(err) goto exit;
460 // Increase sa refcount now the skb context refers to it.
461 SAState_incref(sa);
462 err = netif_rx(skb);
463 exit:
464 if(mine) err = 1;
465 dprintf("< skb=%p err=%d\n", skb, err);
466 return err;
467 }
469 /** Estimate the packet size for some data using ESP processing.
470 *
471 * @param sa ESP SA
472 * @param data_n data size
473 * @return size after ESP processing
474 */
475 static u32 esp_sa_size(SAState *sa, int data_n){
476 // Even in transport mode have to round up to blocksize.
477 // Have to add some padding for alignment even if pad_n is zero.
478 ESPState *esp = sa->data;
480 data_n = roundup(data_n + ESP_PAD_N, esp->cipher.block_n);
481 if(esp->cipher.pad_n > 0){
482 data_n = roundup(data_n, esp->cipher.pad_n);
483 }
484 data_n += esp->digest.icv_n;
485 //data_n += esp->cipher.iv_n;
486 data_n += ESP_HDR_N;
487 return data_n;
488 }
490 /** Compute an icv using HMAC digest.
491 *
492 * @param esp ESP state
493 * @param skb packet to digest
494 * @param offset offset to start at
495 * @param len number of bytes to digest
496 * @param icv return parameter for ICV
497 * @return 0 on success, negative error code otherwise
498 */
499 static inline void esp_hmac_digest(ESPState *esp, struct sk_buff *skb,
500 int offset, int len, u8 *icv){
501 int err = 0;
502 struct crypto_tfm *digest = esp->digest.tfm;
503 char *icv_tmp = esp->digest.icv_tmp;
504 int sg_n = skb_shinfo(skb)->nr_frags + 1;
505 struct scatterlist sg[sg_n];
507 dprintf("> offset=%d len=%d\n", offset, len);
508 memset(icv, 0, esp->digest.icv_n);
509 if(DEBUG_ICV){
510 dprintf("> key len=%d\n", esp->digest.key_n);
511 printk("\nkey=");
512 buf_print(esp->digest.key,esp->digest.key_n);
513 }
514 crypto_hmac_init(digest, esp->digest.key, &esp->digest.key_n);
515 err = skb_scatterlist(skb, sg, &sg_n, offset, len);
516 crypto_hmac_update(digest, sg, sg_n);
517 crypto_hmac_final(digest, esp->digest.key, &esp->digest.key_n, icv_tmp);
518 if(DEBUG_ICV){
519 dprintf("> digest len=%d ", esp->digest.icv_n);
520 printk("\nval=");
521 buf_print(icv_tmp, esp->digest.icv_n);
522 }
523 memcpy(icv, icv_tmp, esp->digest.icv_n);
524 dprintf("<\n");
525 }
527 /** Finish up an esp state.
528 * Releases the digest, cipher, iv and frees the state.
529 *
530 * @parma esp state
531 */
532 static void esp_fini(ESPState *esp){
533 if(!esp) return;
534 if(esp->digest.tfm){
535 crypto_free_tfm(esp->digest.tfm);
536 esp->digest.tfm = NULL;
537 }
538 if(esp->digest.icv_tmp){
539 kfree(esp->digest.icv_tmp);
540 esp->digest.icv_tmp = NULL;
541 }
542 if(esp->cipher.tfm){
543 crypto_free_tfm(esp->cipher.tfm);
544 esp->cipher.tfm = NULL;
545 }
546 if(esp->cipher.iv){
547 kfree(esp->cipher.iv);
548 esp->cipher.iv = NULL;
549 }
550 kfree(esp);
551 }
553 /** Release an ESP SA.
554 *
555 * @param sa ESO SA
556 */
557 static void esp_sa_fini(SAState *sa){
558 ESPState *esp;
559 if(!sa) return;
560 esp = sa->data;
561 if(!esp) return;
562 esp_fini(esp);
563 sa->data = NULL;
564 }
566 /** Initialize the cipher for an ESP SA.
567 *
568 * @param sa ESP SA
569 * @param esp ESP state
570 * @return 0 on success, negative error code otherwise
571 */
572 static int esp_cipher_init(SAState *sa, ESPState *esp){
573 int err = 0;
574 SAAlgorithm *algo = NULL;
575 int cipher_mode = CRYPTO_TFM_MODE_CBC;
577 dprintf("> sa=%p esp=%p\n", sa, esp);
578 dprintf("> cipher=%s\n", sa->cipher.name);
579 algo = sa_cipher_by_name(sa->cipher.name);
580 if(!algo){
581 wprintf("> Cipher unavailable: %s\n", sa->cipher.name);
582 err = -EINVAL;
583 goto exit;
584 }
585 esp->cipher.key_n = roundup(sa->cipher.bits, 8);
586 // If cipher is null must use ECB because CBC algo does not support blocksize 1.
587 if(strcmp(sa->cipher.name, "cipher_null")){
588 cipher_mode = CRYPTO_TFM_MODE_ECB;
589 }
590 esp->cipher.tfm = crypto_alloc_tfm(sa->cipher.name, cipher_mode);
591 if(!esp->cipher.tfm){
592 err = -ENOMEM;
593 goto exit;
594 }
595 esp->cipher.block_n = roundup(crypto_tfm_alg_blocksize(esp->cipher.tfm), 4);
596 esp->cipher.iv_n = crypto_tfm_alg_ivsize(esp->cipher.tfm);
597 esp->cipher.pad_n = 0;
598 if(esp->cipher.iv_n){
599 esp->cipher.iv = kmalloc(esp->cipher.iv_n, GFP_KERNEL);
600 get_random_bytes(esp->cipher.iv, esp->cipher.iv_n);
601 }
602 crypto_cipher_setkey(esp->cipher.tfm, esp->cipher.key, esp->cipher.key_n);
603 err = 0;
604 exit:
605 dprintf("< err=%d\n", err);
606 return err;
607 }
609 /** Initialize the digest for an ESP SA.
610 *
611 * @param sa ESP SA
612 * @param esp ESP state
613 * @return 0 on success, negative error code otherwise
614 */
615 static int esp_digest_init(SAState *sa, ESPState *esp){
616 int err = 0;
617 SAAlgorithm *algo = NULL;
619 dprintf(">\n");
620 esp->digest.key = sa->digest.key;
621 esp->digest.key_n = bits_to_bytes(roundup(sa->digest.bits, 8));
622 esp->digest.tfm = crypto_alloc_tfm(sa->digest.name, 0);
623 if(!esp->digest.tfm){
624 err = -ENOMEM;
625 goto exit;
626 }
627 algo = sa_digest_by_name(sa->digest.name);
628 if(!algo){
629 wprintf("> Digest unavailable: %s\n", sa->digest.name);
630 err = -EINVAL;
631 goto exit;
632 }
633 esp->digest.icv = esp_hmac_digest;
634 esp->digest.icv_full_n = bits_to_bytes(algo->info.digest.icv_fullbits);
635 esp->digest.icv_n = bits_to_bytes(algo->info.digest.icv_truncbits);
637 if(esp->digest.icv_full_n != crypto_tfm_alg_digestsize(esp->digest.tfm)){
638 err = -EINVAL;
639 wprintf("> digest %s, size %u != %hu\n",
640 sa->digest.name,
641 crypto_tfm_alg_digestsize(esp->digest.tfm),
642 esp->digest.icv_full_n);
643 goto exit;
644 }
646 esp->digest.icv_tmp = kmalloc(esp->digest.icv_full_n, GFP_KERNEL);
647 if(!esp->digest.icv_tmp){
648 err = -ENOMEM;
649 goto exit;
650 }
651 exit:
652 dprintf("< err=%d\n", err);
653 return err;
654 }
656 /** Initialize an ESP SA.
657 *
658 * @param sa ESP SA
659 * @param args arguments
660 * @return 0 on success, negative error code otherwise
661 */
662 static int esp_sa_init(SAState *sa, void *args){
663 int err = 0;
664 ESPState *esp = NULL;
666 dprintf("> sa=%p\n", sa);
667 esp = kmalloc(sizeof(*esp), GFP_KERNEL);
668 if(!esp){
669 err = -ENOMEM;
670 goto exit;
671 }
672 *esp = (ESPState){};
673 err = esp_cipher_init(sa, esp);
674 if(err) goto exit;
675 err = esp_digest_init(sa, esp);
676 if(err) goto exit;
677 sa->data = esp;
678 exit:
679 if(err){
680 if(esp) esp_fini(esp);
681 }
682 dprintf("< err=%d\n", err);
683 return err;
684 }
686 /** SA type for ESP.
687 */
688 static SAType esp_sa_type = {
689 .name = "ESP",
690 .protocol = IPPROTO_ESP,
691 .init = esp_sa_init,
692 .fini = esp_sa_fini,
693 .size = esp_sa_size,
694 .recv = esp_sa_recv,
695 .send = esp_sa_send
696 };
698 /** Get the ESP header from a packet.
699 *
700 * @param skb packet
701 * @param esph return parameter for header
702 * @return 0 on success, negative error code otherwise
703 */
704 static int esp_skb_header(struct sk_buff *skb, ESPHdr **esph){
705 int err = 0;
706 if(skb->len < ESP_HDR_N){
707 err = -EINVAL;
708 goto exit;
709 }
710 *esph = (ESPHdr*)skb->data;
711 exit:
712 return err;
713 }
715 /** Handle an incoming skb with ESP protocol.
716 *
717 * Lookup spi, if state found hand to the state.
718 * If no state, check spi, if ok, create state and pass to it.
719 * If spi not ok, drop.
720 *
721 * @param skb packet
722 * @return 0 on sucess, negative error code otherwise
723 */
724 static int esp_protocol_recv(struct sk_buff *skb){
725 int err = 0;
726 const int eth_n = ETH_HLEN;
727 int ip_n;
728 ESPHdr *esph = NULL;
729 SAState *sa = NULL;
730 u32 addr;
732 dprintf(">\n");
733 dprintf("> recv skb=\n"); skb_print_bits(skb, 0, skb->len);
734 ip_n = (skb->nh.iph->ihl << 2);
735 if(skb->data == skb->mac.raw){
736 // skb->data points at ethernet header.
737 if (!pskb_may_pull(skb, eth_n + ip_n)){
738 wprintf("> Malformed skb\n");
739 err = -EINVAL;
740 goto exit;
741 }
742 skb_pull(skb, eth_n + ip_n);
743 }
744 addr = skb->nh.iph->daddr;
745 err = esp_skb_header(skb, &esph);
746 if(err) goto exit;
747 dprintf("> spi=%08x protocol=%d addr=" IPFMT "\n",
748 esph->spi, IPPROTO_ESP, NIPQUAD(addr));
749 sa = sa_table_lookup_spi(esph->spi, IPPROTO_ESP, addr);
750 if(!sa){
751 err = vnet_sa_create(esph->spi, IPPROTO_ESP, addr, &sa);
752 if(err) goto exit;
753 }
754 err = SAState_recv(sa, skb);
755 exit:
756 if(sa) SAState_decref(sa);
757 dprintf("< err=%d\n", err);
758 return err;
759 }
761 /** Handle an ICMP error related to ESP.
762 *
763 * @param skb ICMP error packet
764 * @param info
765 */
766 static void esp_protocol_icmp_err(struct sk_buff *skb, u32 info){
767 struct iphdr *iph = (struct iphdr*)skb->data;
768 ESPHdr *esph;
769 SAState *sa;
771 dprintf("> ICMP error type=%d code=%d\n",
772 skb->h.icmph->type, skb->h.icmph->code);
773 if(skb->h.icmph->type != ICMP_DEST_UNREACH ||
774 skb->h.icmph->code != ICMP_FRAG_NEEDED){
775 return;
776 }
778 //todo: need to check skb has enough len to do this.
779 esph = (ESPHdr*)(skb->data + (iph->ihl << 2));
780 sa = sa_table_lookup_spi(esph->spi, IPPROTO_ESP, iph->daddr);
781 if(!sa) return;
782 wprintf("> ICMP unreachable on SA ESP spi=%08x addr=" IPFMT "\n",
783 ntohl(esph->spi), NIPQUAD(iph->daddr));
784 SAState_decref(sa);
785 }
787 //============================================================================
788 #if LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,0)
789 // Code for 2.6 kernel.
791 /** Protocol handler for ESP.
792 */
793 static struct net_protocol esp_protocol = {
794 .handler = esp_protocol_recv,
795 .err_handler = esp_protocol_icmp_err
796 };
798 static int esp_protocol_add(void){
799 return inet_add_protocol(&esp_protocol, IPPROTO_ESP);
800 }
802 static int esp_protocol_del(void){
803 return inet_del_protocol(&esp_protocol, IPPROTO_ESP);
804 }
806 //============================================================================
807 #else
808 //============================================================================
809 // Code for 2.4 kernel.
811 /** Protocol handler for ESP.
812 */
813 static struct inet_protocol esp_protocol = {
814 .name = "ESP",
815 .protocol = IPPROTO_ESP,
816 .handler = esp_protocol_recv,
817 .err_handler = esp_protocol_icmp_err
818 };
820 static int esp_protocol_add(void){
821 inet_add_protocol(&esp_protocol);
822 return 0;
823 }
825 static int esp_protocol_del(void){
826 return inet_del_protocol(&esp_protocol);
827 }
829 #endif
830 //============================================================================
833 /** Initialize the ESP module.
834 * Registers the ESP protocol and SA type.
835 *
836 * @return 0 on success, negative error code otherwise
837 */
838 int __init esp_module_init(void){
839 int err = 0;
840 dprintf(">\n");
841 err = SAType_add(&esp_sa_type);
842 if(err < 0){
843 eprintf("> Error adding esp sa type\n");
844 goto exit;
845 }
846 esp_protocol_add();
847 exit:
848 dprintf("< err=%d\n", err);
849 return err;
850 }
852 /** Finalize the ESP module.
853 * Deregisters the ESP protocol and SA type.
854 */
855 void __exit esp_module_exit(void){
856 if(esp_protocol_del() < 0){
857 eprintf("> Error removing esp protocol\n");
858 }
859 if(SAType_del(&esp_sa_type) < 0){
860 eprintf("> Error removing esp sa type\n");
861 }
862 }