AES encryption between Arduino and Golang Server

Arduino와 Golang간 AES 암호화통신 구현하기

IoT용으로 원격에서 아두이노 컨트롤 기능을 HTTPS위에 Rest API로 구현을 했었는데
아두이노를 산간오지에 둬야하다보니
데이터통신요금을 줄이고자 TCP소켓통신으로 변경.
암호화가 필요할거 같아 AES암호화를 적용하면서 좌충우돌했던 경험담을 정리합니다.

<준비물>
Wemos D1 mini
Golang돌릴수 있는 PC

<참고사이트>
아두이노쪽 : https://github.com/kakopappa/arduino-esp8266-aes-lib
golang 서버: http://pyrasis.com/book/GoForTheReallyImpatient/Unit53/02

처음엔 결과값이 달라서 라이브러리만 찾으러다니다 몇가지 사실을 알게됐음.
정리하자면 AES128 CBC를 쓰는 라이브러리인데 pkcs7으로 패딩해줍니다.
다만 암호화할 대상 문자열을 base64인코딩부터 해준다음 AES암호화 합니다.

좀 더 부연설명하자면 암호화는 아래와 같이 이뤄집니다.
암호화할 대상이 “text”, 암호화키는 아두이노와 서버간 이미 공유했다고 가정

  1. “text”를 base64로 인코딩 “dGV4dA==”을 편의상 b64_text 변수라 가정
  2. AES암호화를 위해 랜덤 IV생성 (IV가 뭔지 찾아보시면 좋을 것 같아요)
  3. b64_text를 AES 암호화 키와 IV를 이용해서 암호화 함 => encrypted_text
  4. IV와 encrypted_text 문자열을 합쳐 다시 base64로 인코딩
  5. b64_encrypted_text 전송

복호화는 당연히 반대순이겠죠? ㅎㅎ

아두이노 쪽입니다.

bool auth_server() {
  if (!is_connected()) {
    Serial.println("Client doesn't connect.");
    return false;
  }
  String json_data = "";
  
  StaticJsonDocument<200> doc;
  doc["cmd"] = "a";
  doc["id"] = ESP.getChipId();
  doc["mac"] = WiFi.macAddress();
  serializeJson(doc, json_data);
  String enc_data = encrypt(json_data);
  Serial.println("auth " + enc_data);
  client.print(enc_data);
  
  delay(3000); // wait for reply
  Serial.print("AuthReceived: ");
  if (client.available()) {
    String line = client.readStringUntil('\n');
    Serial.println(line);
    return true;
  }
  else {
    Serial.println("can't wait");
    return false;
  }
}

위 코드는 golang서버와 인증하는 부분인데
사용할 때 encrypt()이 1~5단계를 수행해줍니다.
호출부분은 encrypt(json_data) 입니다

encrypt와 decrypt는 아래에 있습니다.

String encrypt(String msg) {
   char b64data[200];
    byte cipher[500];
    byte iv [N_BLOCK] ;

    // The unitialized Initialization vector
    byte my_iv[N_BLOCK] = { 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
    
    // Our message to encrypt. Static for this example.
    //String msg = "{\"data\":{\"value\":300}, \"SEQN\":700 , \"msg\":\"IT WORKS!!\" }";

    aes.set_key( key , sizeof(key));  // Get the globally defined key
    gen_iv( my_iv );                  // Generate a random IV
    
    // Print the IV
    base64_encode( b64data, (char *)my_iv, N_BLOCK);

    char pre_iv[16];
    //Serial.println(sizeof(my_iv));
    //Serial.print(" iv : ");
    
    for (int i=0; i<sizeof(my_iv); i++) {
      if (my_iv[i]<16) {
        //Serial.print(0,HEX);
      }
      //Serial.print(my_iv[i], HEX);
      pre_iv[i] = my_iv[i];
    }
    Serial.println("");
    
    //Serial.println(String(b64data));
    /*Serial.print(" key:");
    for (int i=0; i<sizeof(key); i++) {
      if (key[i]<16) {
        Serial.print(0,HEX);
      }
      Serial.print(key[i], HEX);
    }
    Serial.println("");*/


    //Serial.println(" Message: " + msg );
    int b64len = base64_encode(b64data, (char *)msg.c_str(),msg.length());
   // Serial.println (" Message in B64: " + String(b64data) );
     
    //Serial.println (" The lenght is:  " + String(b64len) );
    //aes.do_aes_encrypt((byte *)(char *)msg.c_str(), msg.length() , cipher, key, 128, my_iv);
    aes.do_aes_encrypt((byte *)b64data, b64len , cipher, key, 128, my_iv);

    //Serial.println("Encryption done!");

    int iv_cipher_size = sizeof(my_iv) + aes.get_size();
    char* iv_cipher = (char*)malloc(sizeof(char) * iv_cipher_size);
    
     for (int i=0; i<sizeof(pre_iv); i++) {
      if (pre_iv[i]<16) {
        //Serial.print(0,HEX);
      }
      //Serial.print(pre_iv[i], HEX);
      iv_cipher[i] = pre_iv[i];
    }
    //Serial.println("");
    //strcpy(iv_cipher, (char*)cipher); // PANIC
    /*
    Serial.print("Cipher: ");
    for (int i=0; i<aes.get_size(); i++) {
      if (cipher[i] <16) {
        Serial.print(0, HEX);
      }
      Serial.print(cipher[i], HEX); 
    }
    Serial.println("");
    */
    for (int i=0; i<aes.get_size();i++) {
      iv_cipher[i+16] = cipher[i];
    }
    
    //Serial.println("Cipher size: " + String(aes.get_size()));
    //base64_encode(b64data, (char *)cipher, aes.get_size());
    base64_encode(b64data, iv_cipher, iv_cipher_size);
    //Serial.println ("Encrypted data in base64: " + String(b64data) );
    
    //Serial.println("");
    Serial.print("Encrypted: " + String(b64data));
    /*
    for (int i=0; i<iv_cipher_size; i++) {
      if (iv_cipher[i]<16) {
        Serial.print(0, HEX);
      }
      Serial.print(iv_cipher[i], HEX);
    }
    Serial.println("");
    */
    free(iv_cipher);

    return String(b64data);
}

String decrypt(String cipher_b64) {
   char b64_decoded[200];
    byte msg[500];
    byte iv [N_BLOCK] ;
    byte cipher[500];

    
    aes.set_key( key , sizeof(key));  // Get the globally defined key
    int encrypted_length = base64_decode(b64_decoded, (char *)cipher_b64.c_str(), cipher_b64.length());
    int cipher_length = encrypted_length-16;
    
    //Serial.print("iv: ");
    for (int i=0; i<sizeof(iv); i++) {
      iv[i] = b64_decoded[i];
      if (b64_decoded[i] < 16) {
        //Serial.print(0, HEX);
      }
      //Serial.print(b64_decoded[i], HEX);
    }
    //Serial.println("");
    
    //Serial.print("Cipher: ");
    for (int i=0; i<cipher_length; i++) {
      cipher[i] = b64_decoded[i+16];
      if (cipher[i] <16) {
        //Serial.print(0, HEX);
      }
      //Serial.print(cipher[i], HEX);
    }
    //Serial.println();

    aes.do_aes_decrypt((byte*)cipher, cipher_length, msg, key, 128, iv);
        
    base64_decode(b64_decoded, (char*)msg, aes.get_size());


    String plain_text = String(b64_decoded);
    //Serial.println("Plain: " + plain_text);
    

    return plain_text;
}

 

Golang 서버쪽입니다.

package lib

import (
	"bytes"
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"log"
	"strings"
)

func encrypt(b cipher.Block, plaintext []byte) []byte {
	plaintext = pkcs7Padding(plaintext)
	/*if mod := len(plaintext) % aes.BlockSize; mod != 0 { // 블록 크기의 배수가 되어야함
		padding := make([]byte, aes.BlockSize-mod)   // 블록 크기에서 모자라는 부분을
		plaintext = append(plaintext, padding...)    // 채워줌
	}*/

	ciphertext := make([]byte, aes.BlockSize+len(plaintext)) // 초기화 벡터 공간(aes.BlockSize)만큼 더 생성
	iv := ciphertext[:aes.BlockSize]                         // 부분 슬라이스로 초기화 벡터 공간을 가져옴

	if _, err := io.ReadFull(rand.Reader, iv); err != nil { // 랜덤 값을 초기화 벡터에 넣어줌
		fmt.Println(err)
		return nil
	}

	mode := cipher.NewCBCEncrypter(b, iv)                   // 암호화 블록과 초기화 벡터를 넣어서 암호화 블록 모드 인스턴스 생성
	mode.CryptBlocks(ciphertext[aes.BlockSize:], plaintext) // 암호화 블록 모드 인스턴스로
	// 암호화

	return ciphertext
}

func decrypt(b cipher.Block, ciphertext []byte) []byte {
	if len(ciphertext)%aes.BlockSize != 0 { // 블록 크기의 배수가 아니면 리턴
		fmt.Println("암호화된 데이터의 길이는 블록 크기의 배수가 되어야합니다.")
		return nil
	}

	iv := ciphertext[:aes.BlockSize]        // 부분 슬라이스로 초기화 벡터 공간을 가져옴
	ciphertext = ciphertext[aes.BlockSize:] // 부분 슬라이스로 암호화된 데이터를 가져옴

	plaintext := make([]byte, len(ciphertext)) // 평문 데이터를 저장할 공간 생성
	mode := cipher.NewCBCDecrypter(b, iv)      // 암호화 블록과 초기화 벡터를 넣어서
	// 복호화 블록 모드 인스턴스 생성
	mode.CryptBlocks(plaintext, ciphertext) // 복호화 블록 모드 인스턴스로 복호화

	return plaintext
}

func Decrypt(CipherText []byte, key []byte) ([]byte, error) {
	cipher, _ := base64.StdEncoding.DecodeString(string(CipherText))
	if len(cipher) < 32 {
		// b64디코딩했는데 최소 크기(aes 128bit=16bytes / iv+cipher=32bytes)보다 작으면 잘못된거죠
		return []byte(""), errors.New("invalid ciphertext len")
	}
	block, err := aes.NewCipher([]byte(key)) // AES 대칭키 암호화 블록 생성
	if err != nil {
		fmt.Println(err)
		return nil, err
	}
	plaintext := decrypt(block, cipher)
	//log.Printf("%s\n",plaintext)
	unpaded, _ := pkcs7UnPadding(plaintext)
	//fmt.Printf("%s\n", unpaded)
	tp, _ := base64DecodeStripped(string(unpaded))
	//tp,_ := base64DecodeStripped(string(plaintext))
	//fmt.Println(string(tp))
	return tp, nil

}
func Encrypt(plainText []byte, key []byte) ([]byte, error) {
	log.Printf("PlainText: %s\t", string(plainText))
	plainTextb64 := base64.StdEncoding.EncodeToString(plainText)

	block, err := aes.NewCipher([]byte(key)) // AES 대칭키 암호화 블록 생성
	if err != nil {
		fmt.Println(err)
		return []byte(""), err
	}
	ciphertext := encrypt(block, []byte(plainTextb64)) // 평문을 AES 알고리즘으로 암호화
	cipherb64 := base64.StdEncoding.EncodeToString(ciphertext)
	return []byte(cipherb64), nil
}

func pkcs7Padding(src []byte) []byte {
	padding := aes.BlockSize - len(src)%aes.BlockSize
	padtext := bytes.Repeat([]byte{byte(padding)}, padding)
	return append(src, padtext...)
}

func pkcs7UnPadding(src []byte) ([]byte, error) {
	length := len(src)
	unpadding := int(src[length-1])

	if unpadding > aes.BlockSize || unpadding == 0 {
		return nil, errors.New("Invalid pkcs7 padding (unpadding > aes.BlockSize || unpadding == 0)")
	}

	pad := src[len(src)-unpadding:]
	for i := 0; i < unpadding; i++ {
		if pad[i] != byte(unpadding) {
			return nil, errors.New("Invalid pkcs7 padding (pad[i] != unpadding)")
		}
	}

	return src[:(length - unpadding)], nil
}
func base64DecodeStripped(s string) ([]byte, error) {
	if i := len(s) % 4; i != 0 {
		s += strings.Repeat("=", 4-i)
	}
	decoded, err := base64.StdEncoding.DecodeString(s)
	return decoded, err
}

참고 사이트에 있던 내용에 base64와 pkcs를 추가해서 조금 수정했습니다.
실제 사용은 Encrypt와 Decrypt를 호출해서 사용하면 되고
사용예제는 아래 lib.Decrypt 부분입니다

			data := recvBuf[:n]
			log.Printf("%s\n", data)

			if len(data) > 20 {
				data, err = lib.Decrypt(data, key)
				if err != nil {
					// error handle
					log.Printf("error: %s\n", err)
					conn.Close()
					return
				} else {
					log.Println(string(data))
				}

			} else {
				log.Println("Data Strange")
				//conn.Write([]byte("Data strange")) // TODO 필요하면 암호화해서 전송해야함
				conn.Close()

				return
			}