Создайте точную модель распознавания рукописного ввода с помощью PyTorch! Узнайте, как использовать пакет MLTU для упрощения конвейера обучения моделей PyTorch, и найдите методы повышения точности вашей модели.
Самые продвинутые планы развития науки о данных, которые вы когда-либо видели! Поставляется с тысячами бесплатных учебных ресурсов и интеграцией ChatGPT! https://aigents.co/learn/roadmaps/intro
В предыдущем уроке я показал вам, как создать пользовательскую модель PyTorch и обучить ее в оболочке, чтобы добиться модульности в нашем конвейере обучения. Это руководство расширит предыдущие руководства до этого, используя Набор данных IAM. Ранее я показал вам, как использовать TensorFlow для обучения модели распознаванию рукописного текста на изображениях. Сейчас сделаю ту же задачу, но с PyTorch!
Набор данных IAM содержит рукописные текстовые изображения, и целью, связанной с каждым образцом, является соответствующая текстовая строка в изображении. Поскольку набор данных IAM обычно используется в качестве эталона для систем OCR, использование этого примера может стать ценной основой для построения вашей собственной системы OCR.
Распознавание рукописного ввода относится к процессу преобразования рукописного текста в текст, который могут интерпретировать машины. Эта технология широко используется в нескольких приложениях, таких как сканирование документов, распознавание рукописных заметок и чтение рукописных форм. Такие приложения включают в себя оцифровку документов, анализ почерка и автоматизацию выставления оценок на экзаменах. Один из подходов к распознаванию рукописного ввода включает использование функции потери временной классификации Connectionist (CTC), которую мы использовали в моих предыдущих руководствах.
Предпосылки:
Прежде чем мы начнем, вам необходимо установить следующее программное обеспечение и пакеты:
- Питон 3;
- факел (в этом уроке мы будем использовать версию 1.13.1);
- млту==1.0.2
- тензорная доска == 2.10.1
- onnx==1.12.0
- факелсводкаX
В этом уроке мы рассмотрим фрагменты кода, используемые для обучения модели распознавания рукописных слов. Код написан на Python и использует PyTorch в качестве среды глубокого обучения. Модель обучается с использованием набора данных IAM, популярного набора данных для распознавания рукописного ввода. Код использует несколько библиотек и методов машинного обучения для предварительной обработки данных, их дополнения и обучения модели глубокого обучения.
Мы начнем с просмотра фрагмента кода построчно, понимая, что делает каждая строка, а затем обсудим, как использовать этот код для обучения модели распознавания рукописных слов.
Начнем с импорта:
Код импортирует несколько модулей и библиотек Python, которые используются в процессе обучения:
- Модуль
os
используется для взаимодействия с операционной системой, например, для чтения или записи файлов, создания каталогов и т. д.; - Модули
tarfile
иzipfile
используются для извлечения файлов из архивов, аurlopen
— для скачивания файлов из Интернета; - Модуль
tqdm
используется для отображения индикаторов выполнения во время обработки данных, что особенно полезно при работе с большими наборами данных; - Модуль
torch
является основным модулем PyTorch и используется для создания и обучения моделей глубокого обучения. - Класс
DataProvider
— это пользовательский класс, который используется для управления и предварительной обработки набора данных. КлассImageReader
считывает изображения из файловой системы, аImageResizer
используется для изменения размера изображений. - Класс
LabelIndexer
преобразует текстовые метки в целочисленные индексы, аLabelPadding
дополняет метки до фиксированной длины. - Классы
RandomBrightness
,RandomRotate
,RandomErodeDilate
иRandomSharpen
используются для увеличения данных. Увеличение данных — это метод, используемый для создания дополнительных обучающих данных путем применения преобразований к существующим данным. Эти методы помогают уменьшить переоснащение и улучшить производительность модели; - Класс
Model
— это пользовательский класс, который используется для обучения и оценки модели глубокого обучения. - Класс
CTCLoss
представляет собой пользовательскую реализацию функции потери Connectionist Temporal Classification, которая обычно используется для обучения моделей распознавания текста; - Классы
CERMetric
иWERMetric
используются для расчета показателей частоты ошибок в символах (CER) и частоты ошибок в словах (WER) во время обучения; - Классы
EarlyStopping
,ModelCheckpoint
,TensorBoard
,Model2onnx
иReduceLROnPlateau
используются в качестве обратных вызовов во время обучения. Обратные вызовы — это функции, которые вызываются во время обучения через определенные промежутки времени. Их можно использовать для ранней остановки обучения, сохранения модели, визуализации показателей обучения и выполнения других полезных функций; Network
— это просто импортированная архитектура модели нейронной сети PyTorch, которую вы можете найти в файлеmodel.py
для более подробной информации. ИModelConfigs
- это наши тренировочные конфигурации, такие как ширина и высота входного изображения, скорость обучения, словарный запас и т. Д. Все, что необходимо для нашего тренировочного процесса.
Загрузка и извлечение набора данных
Следующим шагом является загрузка и извлечение набора данных:
Набор данных загружается и извлекается с помощью функции download_and_unzip
. Функция принимает URL-адрес набора данных в качестве входных данных и загружает его в каталог «Datasets
». Функция использует urlopen
для открытия URL-адреса и tqdm
для отображения индикатора выполнения при загрузке набора данных. После загрузки набора данных он извлекается с помощью ZipFile
. Извлеченный набор данных сохраняется в каталоге «Datasets/IAM_Words/words
».
Предварительная обработка набора данных
После того, как набор данных загружен и извлечен, следующим шагом является предварительная обработка набора данных:
Этот фрагмент кода выполняет предварительную обработку данных путем анализа файла words.txt
и заполнения трех переменных: dataset
, vocab
и max_len
. Набор данных представляет собой список, содержащий списки. Каждый внутренний список содержит путь к файлу и его метку. Словарь представляет собой набор уникальных символов, присутствующих в метках. max_len
— это максимальная длина этикеток.
Для каждой строки в файле код выполняет следующие задачи:
- Пропускает строку, если она начинается с #;
- Пропускает строку, если второй элемент после разделения строки пробелом имеет значение «err»;
- Извлекает первые три и восемь символов имени файла и метки соответственно;
- Объединяет dataset_path с извлеченными именами папок и файлов, чтобы сформировать путь к файлу;
- Пропускает строку, если путь к файлу не существует;
- В противном случае он добавляет путь к файлу и
label
в список наборов данных. Кроме того, он обновляет набор слов с помощью символов, присутствующих в метке, и обновляет переменнуюmax_len
, чтобы она содержала максимальное значение текущегоmax_len
и длину метки;
После предварительной обработки набора данных код сохраняет словарь и максимальную длину текста в конфигах, используя класс ModelConfigs
.
Подготовьте поставщика данных:
Набор данных IAM Words содержит изображения рукописного текста и соответствующие им транскрипции. Предварительная обработка выполняется с использованием класса DataProvider
из библиотеки mltu.torch
:
Класс DataProvider
инициализируется со следующими параметрами:
- набор данных: набор данных, который будет использоваться для обучения и проверки;
- skip_validation: логическое значение, указывающее, следует ли пропускать проверку или нет;
- batch_size: размер пакета для обучения модели глубокого обучения;
- data_preprocessors: список препроцессоров данных, которые нужно применить к набору данных. В данном коде используется только
ImageReader()
, который считывает изображения из файловой системы; - преобразователи: список преобразователей, которые необходимо применить к набору данных. В данном коде используются следующие преобразователи:
ImageResizer()
: Изменяет размеры изображений до указанной высоты и ширины;LabelIndexer()
: преобразует транскрипции в индексы;LabelPadding()
: дополняет транскрипцию значением заполнения, чтобы сделать- use_cache: отметьте, следует ли хранить наши обучающие данные в оперативной памяти для более быстрой предварительной обработки или нет.
Мы можем использовать гораздо больше функций с этим объектом, но сейчас это самое важное, что нам нужно для нашего проекта.
Мы можем использовать функцию ImageShowCV2()
для визуализации наших изображений, которые передаются в data_provider во время итерации нашего генератора data_provider. Раскомментируйте эту строку и сразу после нее вставьте следующие строки:
И теперь, если мы запустим наш скрипт до этого момента, мы должны увидеть похожие результаты на нашем экране:
После создания поставщика данных код разбивает его на наборы для обучения и проверки, используя метод разделения:
Теперь, когда у нас есть разделение наших данных обучения и проверки, нам нужно добавить несколько методов дополнения к нашему поставщику данных обучения, которые помогут нам обучить лучшую модель:
Создайте модель PyTorch:
После создания поставщиков данных мы создаем объект Network, передавая длину словаря и другие параметры конфигурации, такие как функция активации и процент отсева:
Это не тема для рассмотрения архитектуры модели, но если вы хотите проверить ее, перейдите к файлу model.py
, который находится между обучающими файлами.
Затем мы создаем объект optim.Adam
для оптимизатора и объект CTCLoss
для функции потерь. CTCLoss
используется потому, что решаемая проблема представляет собой проблему распознавания символов, где длина входной и выходной последовательности не обязательно одинакова:
Если нам интересно, мы можем распечатать сводку нашей модели:
Это даст следующие результаты:
И если у нас есть GPU на нашем устройстве, необходимо разместить нашу модель на устройстве GPU:
Теперь давайте перейдем к следующему разделу кода, где класс Model определен в модуле mltu.torch.model
. Класс Model — это высокоуровневый интерфейс, объединяющий низкоуровневую функциональность PyTorch, упрощающую обучение и тестирование моделей нейронных сетей. Он берет на себя весь шаблонный код, необходимый для обучения нейронной сети, такой как прямые и обратные проходы, расчет потерь и обновление параметров, позволяя пользователю сосредоточиться на архитектуре модели и данных.
Класс Model принимает четыре основных аргумента: network
, optimizer
, loss
и metrics
. Network
— это экземпляр модели нейронной сети PyTorch, которая принимает входные тензоры и возвращает выходные тензоры. Optimizer
— это экземпляр оптимизатора PyTorch, который принимает параметры сети в качестве входных данных и обновляет их на основе градиентов, вычисленных во время обратного распространения. Loss
— это экземпляр функции потерь PyTorch, которая вычисляет разницу между прогнозами сети и фактическими метками. Metrics
— это список экземпляров пользовательских метрик, которые оценивают производительность сети на проверочном наборе во время обучения.
Класс Model
имеет несколько методов для обучения и тестирования модели. Метод fit
обучает модель в течение определенного количества эпох, используя поставщиков данных для обучения и проверки, и применяет указанные обратные вызовы на каждом этапе обучения. Метод оценки оценивает производительность модели в заданном наборе данных с использованием указанных показателей.
В коде объект модели создается с экземпляром нейронной сети, оптимизатором Adam, потерей функции потерь CTCLoss и списком из двух пользовательских метрик: CERMetric
и WERMetric
. CERMetric вычисляет коэффициент ошибок символов (CER), который представляет собой отношение количества неправильных символов к общему количеству символов в прогнозах. WERMetric вычисляет коэффициент ошибок в словах (WER), который представляет собой отношение количества неправильно распознанных слов к общему количеству слов в предсказаниях.
Наконец, метод подгонки экземпляра модели вызывается с наборами данных train_dataProvider
и test_dataProvider
, количеством эпох, равным 1000, и списком из пяти обратных вызовов: earlyStopping
, modelCheckpoint
, tb_callback
, reduce_lr
и model2onnx
. Эти обратные вызовы определены в модуле mltu.torch.callbacks
и используются для сохранения лучшей модели на основе потери проверки, досрочной остановки процесса обучения, если производительность модели не улучшается в течение заданного количества эпох, регистрации показателей производительности модели с помощью TensorBoard, уменьшения скорость обучения, когда производительность модели стабилизируется, и сохраните модель в формате ONNX для развертывания.
В конце процесса обучения набор данных train_dataProvider
и набор данных test_dataProvider
сохраняются в виде файлов CSV в каталоге configs.model_path
. Эти файлы CSV можно использовать для дальнейшего анализа набора данных и производительности модели.
Производительность обучения на TensorBoard:
Точно так же, как мы это делаем в TensorFlow, мы можем проверять журналы обучения и проверки в Tensorboard. В моем случае это так же просто, как вызвать тензорную доску с путем к нашей папке журналов:
Он дает нам ссылку, которую нам нужно открыть:
Там мы можем анализировать целые кривые обучения и проверки, и меня больше всего интересует кривая CER (коэффициент ошибок):
Из приведенной выше кривой (CER) мы видим, что наша модель определенно обучалась, пока кривая продолжала уменьшаться. Мы видим, что все обучение заняло около 400 тренировочных эпох, а лучшая модель сохранилась где-то на 350-м шаге. Там он остается где-то около 0,12 CER; это неплохо! Это означает, что из строки вероятность того, что наша модель сделает ошибку, составляет около 12%. Сравнительно ошеломляющие результаты!
Приведенная выше кривая потерь не дает нам никакой полезной информации, кроме того, что она продолжает уменьшаться, что означает, что наша модель продолжает обучение.
Модель ONNX, прошедшая тестирование:
Моя обученная модель была сохранена в формате «Models/08_handwriting_recognition_torch/202303142139/model.onnx
» onnx, что позволяет нам загружать ее с выводом onnx и использовать прямо из коробки! Поскольку наше обучение завершено, мы хотим протестировать его, чтобы увидеть фактические прогнозы в строковом формате. Вот код для циклического просмотра нашего набора данных проверки:
Приведенный выше код определяет класс с именем ImageToWordModel
, который расширяет класс OnnxInferenceModel
из модуля mltu.inferenceModel
. Этот класс предназначен для предсказания слова по изображению почерка.
Метод predict
берет входное изображение и возвращает предсказанный текст, сначала изменяя размер изображения до входной формы модели, пропуская изображение через модель, а затем используя декодер CTC для декодирования выходных предсказаний. Прогнозируемый текст возвращается методом predict
.
Блок __main__
импортирует модули pandas
и tqdm
и создает экземпляр класса ImageToWordModel
, передавая путь к файлу модели ONNX. Файл CSV, содержащий пути к изображениям и соответствующие им метки, считывается во фрейм данных Pandas, и метод прогнозирования вызывается для каждого изображения во фрейме данных.
Прогнозируемый текст сравнивается с наземной меткой истинности для каждого изображения с использованием функции get_cer
из модуля mltu.utils.text_utils
для расчета коэффициента ошибок символов (CER). Путь к изображению, метка истинности, предсказанный текст и CER выводятся на консоль. Значения CER для всех изображений накапливаются в списке, а среднее значение CER выводится на консоль после обработки всех изображений.
Чтобы использовать этот код для задачи распознавания рукописного ввода, необходимо предоставить собственный файл модели ONNX и файл CSV, содержащий пути к изображениям и соответствующие им метки. Затем вы можете изменить метод прогнозирования, чтобы выполнить любые дополнительные шаги предварительной или последующей обработки, необходимые для вашей конкретной задачи.
Если я запускаю приведенный выше скрипт в консоли, он дает мне следующие результаты:
Здесь мы можем увидеть фактическую метку в наборе данных и прогнозируемую метку. Кроме того, он сообщает нам CER, и мы можем понять, где он допустил какие-либо ошибки, если они это сделали. Вот несколько изображений из моего тестового набора данных:
Заключение:
В заключение в этом руководстве представлена реализация модели PyTorch для распознавания рукописного текста на изображениях с использованием набора данных IAM. Мы использовали несколько методов машинного обучения для предварительной обработки данных, их дополнения и обучения модели глубокого обучения, включая увеличение данных и функцию потери временной классификации Connectionist (CTC). Набор данных IAM является популярным эталоном для систем OCR, что делает это руководство отличной отправной точкой для создания вашей системы OCR.
В руководстве также рассматривалась важность обратных вызовов и реализация пользовательских классов для управления набором данных и его предварительной обработки, расчета показателей оценки и сохранения модели. В целом, это руководство предоставило исчерпывающее руководство по созданию модели распознавания рукописных слов с использованием PyTorch, которая может быть полезна в нескольких приложениях, включая оцифровку документов, анализ рукописного текста и автоматизацию оценки экзаменов.
Обученную модель, используемую в этом руководстве, можно скачать по этой ссылке.
Полный код этого урока вы можете найти по этой ссылке на GitHub.
Первоначально опубликовано на https://pylessons.com/handwriting-recognition-pytorch