Создайте точную модель распознавания рукописного ввода с помощью PyTorch! Узнайте, как использовать пакет MLTU для упрощения конвейера обучения моделей PyTorch, и найдите методы повышения точности вашей модели.

Самые продвинутые планы развития науки о данных, которые вы когда-либо видели! Поставляется с тысячами бесплатных учебных ресурсов и интеграцией ChatGPT! https://aigents.co/learn/roadmaps/intro

В предыдущем уроке я показал вам, как создать пользовательскую модель PyTorch и обучить ее в оболочке, чтобы добиться модульности в нашем конвейере обучения. Это руководство расширит предыдущие руководства до этого, используя Набор данных IAM. Ранее я показал вам, как использовать TensorFlow для обучения модели распознаванию рукописного текста на изображениях. Сейчас сделаю ту же задачу, но с PyTorch!

Набор данных IAM содержит рукописные текстовые изображения, и целью, связанной с каждым образцом, является соответствующая текстовая строка в изображении. Поскольку набор данных IAM обычно используется в качестве эталона для систем OCR, использование этого примера может стать ценной основой для построения вашей собственной системы OCR.

Распознавание рукописного ввода относится к процессу преобразования рукописного текста в текст, который могут интерпретировать машины. Эта технология широко используется в нескольких приложениях, таких как сканирование документов, распознавание рукописных заметок и чтение рукописных форм. Такие приложения включают в себя оцифровку документов, анализ почерка и автоматизацию выставления оценок на экзаменах. Один из подходов к распознаванию рукописного ввода включает использование функции потери временной классификации Connectionist (CTC), которую мы использовали в моих предыдущих руководствах.

Предпосылки:

Прежде чем мы начнем, вам необходимо установить следующее программное обеспечение и пакеты:

  • Питон 3;
  • факел (в этом уроке мы будем использовать версию 1.13.1);
  • млту==1.0.2
  • тензорная доска == 2.10.1
  • onnx==1.12.0
  • факелсводкаX

В этом уроке мы рассмотрим фрагменты кода, используемые для обучения модели распознавания рукописных слов. Код написан на Python и использует PyTorch в качестве среды глубокого обучения. Модель обучается с использованием набора данных IAM, популярного набора данных для распознавания рукописного ввода. Код использует несколько библиотек и методов машинного обучения для предварительной обработки данных, их дополнения и обучения модели глубокого обучения.

Мы начнем с просмотра фрагмента кода построчно, понимая, что делает каждая строка, а затем обсудим, как использовать этот код для обучения модели распознавания рукописных слов.

Начнем с импорта:

Код импортирует несколько модулей и библиотек Python, которые используются в процессе обучения:

  • Модуль os используется для взаимодействия с операционной системой, например, для чтения или записи файлов, создания каталогов и т. д.;
  • Модули tarfile и zipfile используются для извлечения файлов из архивов, а urlopen — для скачивания файлов из Интернета;
  • Модуль tqdm используется для отображения индикаторов выполнения во время обработки данных, что особенно полезно при работе с большими наборами данных;
  • Модуль torch является основным модулем PyTorch и используется для создания и обучения моделей глубокого обучения.
  • Класс DataProvider — это пользовательский класс, который используется для управления и предварительной обработки набора данных. Класс ImageReader считывает изображения из файловой системы, а ImageResizer используется для изменения размера изображений.
  • Класс LabelIndexer преобразует текстовые метки в целочисленные индексы, а LabelPadding дополняет метки до фиксированной длины.
  • Классы RandomBrightness, RandomRotate, RandomErodeDilate и RandomSharpen используются для увеличения данных. Увеличение данных — это метод, используемый для создания дополнительных обучающих данных путем применения преобразований к существующим данным. Эти методы помогают уменьшить переоснащение и улучшить производительность модели;
  • Класс Model — это пользовательский класс, который используется для обучения и оценки модели глубокого обучения.
  • Класс CTCLoss представляет собой пользовательскую реализацию функции потери Connectionist Temporal Classification, которая обычно используется для обучения моделей распознавания текста;
  • Классы CERMetric и WERMetric используются для расчета показателей частоты ошибок в символах (CER) и частоты ошибок в словах (WER) во время обучения;
  • Классы EarlyStopping, ModelCheckpoint, TensorBoard, Model2onnx и ReduceLROnPlateau используются в качестве обратных вызовов во время обучения. Обратные вызовы — это функции, которые вызываются во время обучения через определенные промежутки времени. Их можно использовать для ранней остановки обучения, сохранения модели, визуализации показателей обучения и выполнения других полезных функций;
  • Network — это просто импортированная архитектура модели нейронной сети PyTorch, которую вы можете найти в файле model.py для более подробной информации. И ModelConfigs - это наши тренировочные конфигурации, такие как ширина и высота входного изображения, скорость обучения, словарный запас и т. Д. Все, что необходимо для нашего тренировочного процесса.

Загрузка и извлечение набора данных

Следующим шагом является загрузка и извлечение набора данных:

Набор данных загружается и извлекается с помощью функции download_and_unzip. Функция принимает URL-адрес набора данных в качестве входных данных и загружает его в каталог «Datasets». Функция использует urlopen для открытия URL-адреса и tqdm для отображения индикатора выполнения при загрузке набора данных. После загрузки набора данных он извлекается с помощью ZipFile. Извлеченный набор данных сохраняется в каталоге «Datasets/IAM_Words/words».

Предварительная обработка набора данных

После того, как набор данных загружен и извлечен, следующим шагом является предварительная обработка набора данных:

Этот фрагмент кода выполняет предварительную обработку данных путем анализа файла words.txt и заполнения трех переменных: dataset, vocab и max_len. Набор данных представляет собой список, содержащий списки. Каждый внутренний список содержит путь к файлу и его метку. Словарь представляет собой набор уникальных символов, присутствующих в метках. max_len — это максимальная длина этикеток.

Для каждой строки в файле код выполняет следующие задачи:

  1. Пропускает строку, если она начинается с #;
  2. Пропускает строку, если второй элемент после разделения строки пробелом имеет значение «err»;
  3. Извлекает первые три и восемь символов имени файла и метки соответственно;
  4. Объединяет dataset_path с извлеченными именами папок и файлов, чтобы сформировать путь к файлу;
  5. Пропускает строку, если путь к файлу не существует;
  6. В противном случае он добавляет путь к файлу и label в список наборов данных. Кроме того, он обновляет набор слов с помощью символов, присутствующих в метке, и обновляет переменную max_len, чтобы она содержала максимальное значение текущего max_len и длину метки;

После предварительной обработки набора данных код сохраняет словарь и максимальную длину текста в конфигах, используя класс ModelConfigs.

Подготовьте поставщика данных:

Набор данных IAM Words содержит изображения рукописного текста и соответствующие им транскрипции. Предварительная обработка выполняется с использованием класса DataProvider из библиотеки mltu.torch:

Класс DataProvider инициализируется со следующими параметрами:

  • набор данных: набор данных, который будет использоваться для обучения и проверки;
  • skip_validation: логическое значение, указывающее, следует ли пропускать проверку или нет;
  • batch_size: размер пакета для обучения модели глубокого обучения;
  • data_preprocessors: список препроцессоров данных, которые нужно применить к набору данных. В данном коде используется только ImageReader(), который считывает изображения из файловой системы;
  • преобразователи: список преобразователей, которые необходимо применить к набору данных. В данном коде используются следующие преобразователи:
  • ImageResizer(): Изменяет размеры изображений до указанной высоты и ширины;
  • LabelIndexer(): преобразует транскрипции в индексы;
  • LabelPadding(): дополняет транскрипцию значением заполнения, чтобы сделать
  • use_cache: отметьте, следует ли хранить наши обучающие данные в оперативной памяти для более быстрой предварительной обработки или нет.

Мы можем использовать гораздо больше функций с этим объектом, но сейчас это самое важное, что нам нужно для нашего проекта.

Мы можем использовать функцию ImageShowCV2() для визуализации наших изображений, которые передаются в data_provider во время итерации нашего генератора data_provider. Раскомментируйте эту строку и сразу после нее вставьте следующие строки:

И теперь, если мы запустим наш скрипт до этого момента, мы должны увидеть похожие результаты на нашем экране:

После создания поставщика данных код разбивает его на наборы для обучения и проверки, используя метод разделения:

Теперь, когда у нас есть разделение наших данных обучения и проверки, нам нужно добавить несколько методов дополнения к нашему поставщику данных обучения, которые помогут нам обучить лучшую модель:

Создайте модель PyTorch:

После создания поставщиков данных мы создаем объект Network, передавая длину словаря и другие параметры конфигурации, такие как функция активации и процент отсева:

Это не тема для рассмотрения архитектуры модели, но если вы хотите проверить ее, перейдите к файлу model.py, который находится между обучающими файлами.

Затем мы создаем объект optim.Adam для оптимизатора и объект CTCLoss для функции потерь. CTCLoss используется потому, что решаемая проблема представляет собой проблему распознавания символов, где длина входной и выходной последовательности не обязательно одинакова:

Если нам интересно, мы можем распечатать сводку нашей модели:

Это даст следующие результаты:

И если у нас есть GPU на нашем устройстве, необходимо разместить нашу модель на устройстве GPU:

Теперь давайте перейдем к следующему разделу кода, где класс Model определен в модуле mltu.torch.model. Класс Model — это высокоуровневый интерфейс, объединяющий низкоуровневую функциональность PyTorch, упрощающую обучение и тестирование моделей нейронных сетей. Он берет на себя весь шаблонный код, необходимый для обучения нейронной сети, такой как прямые и обратные проходы, расчет потерь и обновление параметров, позволяя пользователю сосредоточиться на архитектуре модели и данных.

Класс Model принимает четыре основных аргумента: network, optimizer, loss и metrics. Network — это экземпляр модели нейронной сети PyTorch, которая принимает входные тензоры и возвращает выходные тензоры. Optimizer — это экземпляр оптимизатора PyTorch, который принимает параметры сети в качестве входных данных и обновляет их на основе градиентов, вычисленных во время обратного распространения. Loss — это экземпляр функции потерь PyTorch, которая вычисляет разницу между прогнозами сети и фактическими метками. Metrics — это список экземпляров пользовательских метрик, которые оценивают производительность сети на проверочном наборе во время обучения.

Класс Model имеет несколько методов для обучения и тестирования модели. Метод fit обучает модель в течение определенного количества эпох, используя поставщиков данных для обучения и проверки, и применяет указанные обратные вызовы на каждом этапе обучения. Метод оценки оценивает производительность модели в заданном наборе данных с использованием указанных показателей.

В коде объект модели создается с экземпляром нейронной сети, оптимизатором Adam, потерей функции потерь CTCLoss и списком из двух пользовательских метрик: CERMetric и WERMetric. CERMetric вычисляет коэффициент ошибок символов (CER), который представляет собой отношение количества неправильных символов к общему количеству символов в прогнозах. WERMetric вычисляет коэффициент ошибок в словах (WER), который представляет собой отношение количества неправильно распознанных слов к общему количеству слов в предсказаниях.

Наконец, метод подгонки экземпляра модели вызывается с наборами данных train_dataProvider и test_dataProvider, количеством эпох, равным 1000, и списком из пяти обратных вызовов: earlyStopping, modelCheckpoint, tb_callback, reduce_lr и model2onnx. Эти обратные вызовы определены в модуле mltu.torch.callbacks и используются для сохранения лучшей модели на основе потери проверки, досрочной остановки процесса обучения, если производительность модели не улучшается в течение заданного количества эпох, регистрации показателей производительности модели с помощью TensorBoard, уменьшения скорость обучения, когда производительность модели стабилизируется, и сохраните модель в формате ONNX для развертывания.

В конце процесса обучения набор данных train_dataProvider и набор данных test_dataProvider сохраняются в виде файлов CSV в каталоге configs.model_path. Эти файлы CSV можно использовать для дальнейшего анализа набора данных и производительности модели.

Производительность обучения на TensorBoard:

Точно так же, как мы это делаем в TensorFlow, мы можем проверять журналы обучения и проверки в Tensorboard. В моем случае это так же просто, как вызвать тензорную доску с путем к нашей папке журналов:

Он дает нам ссылку, которую нам нужно открыть:

Там мы можем анализировать целые кривые обучения и проверки, и меня больше всего интересует кривая CER (коэффициент ошибок):

Из приведенной выше кривой (CER) мы видим, что наша модель определенно обучалась, пока кривая продолжала уменьшаться. Мы видим, что все обучение заняло около 400 тренировочных эпох, а лучшая модель сохранилась где-то на 350-м шаге. Там он остается где-то около 0,12 CER; это неплохо! Это означает, что из строки вероятность того, что наша модель сделает ошибку, составляет около 12%. Сравнительно ошеломляющие результаты!

Приведенная выше кривая потерь не дает нам никакой полезной информации, кроме того, что она продолжает уменьшаться, что означает, что наша модель продолжает обучение.

Модель ONNX, прошедшая тестирование:

Моя обученная модель была сохранена в формате «Models/08_handwriting_recognition_torch/202303142139/model.onnx» onnx, что позволяет нам загружать ее с выводом onnx и использовать прямо из коробки! Поскольку наше обучение завершено, мы хотим протестировать его, чтобы увидеть фактические прогнозы в строковом формате. Вот код для циклического просмотра нашего набора данных проверки:

Приведенный выше код определяет класс с именем ImageToWordModel, который расширяет класс OnnxInferenceModel из модуля mltu.inferenceModel. Этот класс предназначен для предсказания слова по изображению почерка.

Метод predict берет входное изображение и возвращает предсказанный текст, сначала изменяя размер изображения до входной формы модели, пропуская изображение через модель, а затем используя декодер CTC для декодирования выходных предсказаний. Прогнозируемый текст возвращается методом predict.

Блок __main__ импортирует модули pandas и tqdm и создает экземпляр класса ImageToWordModel, передавая путь к файлу модели ONNX. Файл CSV, содержащий пути к изображениям и соответствующие им метки, считывается во фрейм данных Pandas, и метод прогнозирования вызывается для каждого изображения во фрейме данных.

Прогнозируемый текст сравнивается с наземной меткой истинности для каждого изображения с использованием функции get_cer из модуля mltu.utils.text_utils для расчета коэффициента ошибок символов (CER). Путь к изображению, метка истинности, предсказанный текст и CER выводятся на консоль. Значения CER для всех изображений накапливаются в списке, а среднее значение CER выводится на консоль после обработки всех изображений.

Чтобы использовать этот код для задачи распознавания рукописного ввода, необходимо предоставить собственный файл модели ONNX и файл CSV, содержащий пути к изображениям и соответствующие им метки. Затем вы можете изменить метод прогнозирования, чтобы выполнить любые дополнительные шаги предварительной или последующей обработки, необходимые для вашей конкретной задачи.

Если я запускаю приведенный выше скрипт в консоли, он дает мне следующие результаты:

Здесь мы можем увидеть фактическую метку в наборе данных и прогнозируемую метку. Кроме того, он сообщает нам CER, и мы можем понять, где он допустил какие-либо ошибки, если они это сделали. Вот несколько изображений из моего тестового набора данных:

Заключение:

В заключение в этом руководстве представлена ​​реализация модели PyTorch для распознавания рукописного текста на изображениях с использованием набора данных IAM. Мы использовали несколько методов машинного обучения для предварительной обработки данных, их дополнения и обучения модели глубокого обучения, включая увеличение данных и функцию потери временной классификации Connectionist (CTC). Набор данных IAM является популярным эталоном для систем OCR, что делает это руководство отличной отправной точкой для создания вашей системы OCR.

В руководстве также рассматривалась важность обратных вызовов и реализация пользовательских классов для управления набором данных и его предварительной обработки, расчета показателей оценки и сохранения модели. В целом, это руководство предоставило исчерпывающее руководство по созданию модели распознавания рукописных слов с использованием PyTorch, которая может быть полезна в нескольких приложениях, включая оцифровку документов, анализ рукописного текста и автоматизацию оценки экзаменов.

Обученную модель, используемую в этом руководстве, можно скачать по этой ссылке.

Полный код этого урока вы можете найти по этой ссылке на GitHub.

Первоначально опубликовано на https://pylessons.com/handwriting-recognition-pytorch