
EasyMMS API Reference

easymms.models

tts

This file contains a class definition to use the TTS models from Meta's Massively Multilingual Speech (MMS) project

TTSModel

TTSModel(lang, model_dir=None, log_level=logging.INFO)

TTS model class

Example usage:

from easymms.models.tts import TTSModel

tts = TTSModel('eng')
res = tts.synthesize("This is a simple example")
tts.save(res)

Use a TTS model by its ISO 639-3 language code. The model will be downloaded automatically if no local model directory is provided.

Parameters:

Name | Type | Description | Default
lang | str | TTS model language | required
model_dir | str | path to a local directory containing the model files (G_100000.pth, config.json, vocab.txt); downloaded automatically if None | None
log_level | int | log level | logging.INFO
Source code in easymms/models/tts.py, lines 46-75
def __init__(self,
             lang: str,
             model_dir: str = None,
             log_level: int = logging.INFO):
    """
    Use a TTS model by its language ISO ID.
    The model will be downloaded automatically

    :param lang: TTS model language
    :param log_level: log level
    """
    self.log_level = log_level
    self.lang = lang
    set_log_level(log_level)
    # check if model_dir is provided
    if model_dir is not None:
        # verify if all files exist
        model_dir_path = Path(model_dir)
    else:
        model_dir_path = self._download_tts_model_files(lang)

    self.cp = model_dir_path / "G_100000.pth"
    assert self.cp.exists(), f"G_100000.pth not found in {str(model_dir_path)}"
    self.config = model_dir_path / "config.json"
    assert self.config.exists(), f"config.json not found in {str(model_dir_path)}"
    self.vocab = model_dir_path / "vocab.txt"
    assert self.vocab.exists(), f"vocab.txt not found in {str(model_dir_path)}"
    self.uroman_dir_path = None

    self._setup()
get_supported_langs staticmethod
get_supported_langs()

Helper function to get the ISO 639-3 languages supported by the TTS models. Source: https://dl.fbaipublicfiles.com/mms/misc/language_coverage_mms.html

Returns:

Type | Description
List[str] | list of supported languages

Source code in easymms/models/tts.py, lines 121-130
@staticmethod
def get_supported_langs() -> List[str]:
    """
    Helper function to get the ISO 639-3 languages supported by the TTS models
    Source <https://dl.fbaipublicfiles.com/mms/misc/language_coverage_mms.html>
    :return: list of supported languages
    """
    with open(constants.MMS_LANGS_FILE) as f:
        data = json.load(f)
        return [key for key in data if data[key]['TTS']]
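For example, a quick check before instantiating the model (a minimal sketch; the language code is only illustrative):

from easymms.models.tts import TTSModel

supported = TTSModel.get_supported_langs()
print(f"{len(supported)} languages have a TTS model")

lang = 'eng'  # illustrative language code
if lang in supported:
    tts = TTSModel(lang)
else:
    print(f"No TTS model available for '{lang}'")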
synthesize
synthesize(txt, device=None)

Synthesizes the text provided as input.

Parameters:

Name | Type | Description | Default
txt | str | text to synthesize | required
device | str | PyTorch device (cpu/cuda) | None

Returns:

Type | Description
Tuple | (data, sample_rate)

Source code in easymms/models/tts.py, lines 132-186
def synthesize(self, txt: str, device=None):
    """
     Synthesizes the text provided as input.

    :param txt: Text
    :param lang: Language
    :param device: Pytorch device (cpu/cuda)
    :return: Tuple(data, sample_rate)
    """
    cwd = os.getcwd()
    os.chdir(constants.VITS_DIR)
    from utils import get_hparams_from_file, load_checkpoint
    from models import SynthesizerTrn
    os.chdir(constants.FAIRSEQ_DIR)
    from examples.mms.tts.infer import TextMapper
    os.chdir(cwd)
    if device is None:
        if torch.cuda.is_available():
            device = 'cuda'
        else:
            device = 'cpu'
    else:
        assert device in ['cpu', 'cuda']

    text_mapper = TextMapper(str(self.vocab))
    hps = get_hparams_from_file(str(self.config))
    net_g = SynthesizerTrn(
        len(text_mapper.symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model)
    net_g.to(device)
    _ = net_g.eval()
    g_pth = self.cp
    logger.info(f"loading {g_pth} ...")
    _ = load_checkpoint(g_pth, net_g, None)
    logger.info(f"text: {txt}")
    is_uroman = hps.data.training_files.split('.')[-1] == 'uroman'
    if is_uroman:
        uroman_pl = str(self.uroman_dir_path / "uroman.pl")
        txt = text_mapper.uromanize(txt, uroman_pl)
        logger.info(f"uroman text: {txt}")
    txt = txt.lower()
    txt = text_mapper.filter_oov(txt, lang=self.lang)
    stn_tst = text_mapper.get_text(txt, hps)
    # inference
    with torch.no_grad():
        x_tst = stn_tst.unsqueeze(0).to(device)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
        hyp = net_g.infer(
            x_tst, x_tst_lengths, noise_scale=.667,
            noise_scale_w=0.8, length_scale=1.0
        )[0][0, 0].cpu().float().numpy()

    return hyp, hps.data.sampling_rate
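As a usage sketch, the returned tuple can be written out directly with soundfile (the same library save uses internally); the device argument and output path below are illustrative:

import soundfile as sf

from easymms.models.tts import TTSModel

tts = TTSModel('eng')
wav, sample_rate = tts.synthesize("Text to speech on a chosen device", device='cpu')
# equivalent to tts.save((wav, sample_rate), 'example.wav')
sf.write('example.wav', wav, sample_rate)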
save
save(tts_data, out_file='out.wav')

Saves the results of the synthesize function to a file

Parameters:

Name | Type | Description | Default
tts_data | Tuple | a tuple of wav data array and sample rate | required
out_file | str | output file path | 'out.wav'

Returns:

Type | Description
Path | out_file absolute path

Source code in easymms/models/tts.py, lines 188-200
def save(self, tts_data: Tuple, out_file='out.wav') -> Path:
    """
    Saves the results of the `synthesize` function to a file

    :param tts_data: a tuple of `wav data array` and `sample rate`
    :param out_file: output file path

    :return: out_file absolute path
    """
    set_log_level(self.log_level)
    logger.info(f"Saving audio file to {out_file}")
    sf.write(out_file, tts_data[0], tts_data[1])
    return Path(out_file).resolve()

alignment

This file contains a simple class (AlignmentModel) to use the alignment model from the MMS project. More info can be found here: https://github.com/facebookresearch/fairseq/tree/main/examples/mms/data_prep

AlignmentModel

AlignmentModel(
    model=None,
    dictionary=None,
    uroman_dir=None,
    log_level=logging.INFO,
)

MMS alignment algorithm.

Example usage:

from easymms.models.alignment import AlignmentModel

align_model = AlignmentModel()
segments = align_model.align('/home/su/code/easymms/assets/eng_1.mp3',
                             transcript=["segment 1", "segment 2"],
                             lang='eng')
for segment in segments:
    print(f"{segment['start_time']} -> {segment['end_time']}: {segment['text']}")
Source code in easymms/models/alignment.py, lines 44-73
def __init__(self,
             model: str = None,
             dictionary: str = None,
             uroman_dir: str = None,
             log_level: int = logging.INFO):
    """
    :param model: path to the alignment model checkpoint; defaults to the copy in the package data directory
    :param dictionary: path to the alignment model dictionary; defaults to the copy in the package data directory
    :param uroman_dir: path to a local uroman checkout; fetched via easymms_utils.get_uroman() if None
    :param log_level: log level
    """
    set_log_level(log_level)
    assert shutil.which("perl") is not None, "To use the alignment algorithm you will need uroman " \
                                             "<https://github.com/isi-nlp/uroman> which is written in perl " \
                                             "please install perl first <https://www.perl.org/get.html>"
    if uroman_dir is not None:
        self.uroman_dir_path = Path(uroman_dir)
    else:
        self.uroman_dir_path = easymms_utils.get_uroman()

    if model is not None:
        self.model_path = Path(model)
    else:
        self.model_path = Path(PACKAGE_DATA_DIR) / "ctc_alignment_mling_uroman_model.pt"

    if dictionary is not None:
        self.dictionary_path = Path(dictionary)
    else:
        self.dictionary_path = Path(PACKAGE_DATA_DIR) / "ctc_alignment_mling_uroman_model.dict"

    self.model, self.dictionary = self._load_model_dict()

    # clone Fairseq
    easymms_utils.clone(constants.FAIRSEQ_URL, constants.FAIRSEQ_DIR)
    sys.path.append(str(constants.FAIRSEQ_DIR.resolve()))
align
align(media_file, transcript, lang, device=None)

Takes a media file, transcript segments and the language, and returns a list of dicts in the following format: [{'start_time': ..., 'end_time': ..., 'text': ..., 'duration': ...}, ...]

Parameters:

Name | Type | Description | Default
media_file | str | the path of the media file, should be wav | required
transcript | List[str] | list of segments | required
lang | str | language ISO code | required
device | str | 'cuda' or 'cpu' | None

Returns:

Type | Description
List[dict] | list of transcription timestamps

Source code in easymms/models/alignment.py, lines 137-195
def align(self,
          media_file: str,
          transcript: List[str],
          lang: str,
          device: str = None) -> List[dict]:
    """
    Takes a media file, transcript segments and the language and returns a list of dicts in the following format
    [{
        'start_time': ...
        'end_time': ...,
        'text': ...,
        'duration': ...
    }, ...]

    :param media_file: the path of the media file, should be wav
    :param transcript: list of segments
    :param lang: language ISO code
    :param device: 'cuda' or 'cpu'
    :return: list of transcription timestamps
    """
    # import
    import os
    cwd = os.getcwd()
    os.chdir(constants.FAIRSEQ_DIR)
    from examples.mms.data_prep.align_and_segment import get_alignments
    from examples.mms.data_prep.align_utils import get_uroman_tokens, get_spans
    from examples.mms.data_prep.text_normalization import text_normalize
    os.chdir(cwd)

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = self.model.to(device)
    res = []
    logger.info(f"Aligning file {media_file} ...")
    norm_transcripts = [text_normalize(line.strip(), lang) for line in transcript]
    tokens = get_uroman_tokens(norm_transcripts, str(self.uroman_dir_path.resolve()), lang)

    segments, stride = get_alignments(
        media_file,
        tokens,
        model,
        self.dictionary,
        use_star=False,
    )
    # Get spans of each line in input text file
    spans = get_spans(tokens, segments)

    for i, t in enumerate(transcript):
        span = spans[i]
        seg_start_idx = span[0].start
        seg_end_idx = span[-1].end
        audio_start_sec = seg_start_idx * stride / 1000
        audio_end_sec = seg_end_idx * stride / 1000
        res.append({'start_time': audio_start_sec,
                    'end_time': audio_end_sec,
                    'text': t,
                    'duration': audio_end_sec - audio_start_sec})

    return res
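As an illustration of the returned structure, the dicts can be turned into subtitle-style lines; the timestamp formatter below is an ad-hoc helper (not part of easymms) and the media path is illustrative:

from easymms.models.alignment import AlignmentModel

def fmt(seconds: float) -> str:
    # seconds -> HH:MM:SS,mmm (SRT-style)
    h, rem = divmod(seconds, 3600)
    m, s = divmod(rem, 60)
    return f"{int(h):02d}:{int(m):02d}:{s:06.3f}".replace('.', ',')

align_model = AlignmentModel()
segments = align_model.align('/path/to/audio.wav',
                             transcript=["first segment", "second segment"],
                             lang='eng')
for i, seg in enumerate(segments, start=1):
    print(i)
    print(f"{fmt(seg['start_time'])} --> {fmt(seg['end_time'])}")
    print(seg['text'])
    print()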

asr

This file contains a class definition to use the ASR models from Meta's Massively Multilingual Speech (MMS) project

ASRModel

ASRModel(model, log_level=logging.INFO)

MMS ASR model class

Example usage:

from easymms.models.asr import ASRModel

asr = ASRModel(model='/path/to/mms/model')
files = ['path/to/media_file_1', 'path/to/media_file_2']
transcriptions = asr.transcribe(files, lang='eng', align=False)
for i, transcription in enumerate(transcriptions):
    print(f">>> file {files[i]}")
    print(transcription)

Parameters:

Name | Type | Description | Default
model | str | path to the ASR model (see https://github.com/facebookresearch/fairseq/tree/main/examples/mms#asr) | required
log_level | int | log level | logging.INFO
Source code in easymms/models/asr.py, lines 61-77
def __init__(self,
             model: str,
             log_level: int = logging.INFO):
    """
    :param model: path to the asr model <https://github.com/facebookresearch/fairseq/tree/main/examples/mms#asr>
    :param log_level: log level
    """
    set_log_level(log_level)
    self.cfg = constants.CFG.copy()
    self.model = Path(model)
    self.cfg['common_eval']['path'] = str(self.model.resolve())

    self.tmp_dir = tempfile.TemporaryDirectory()
    self.tmp_dir_path = Path(self.tmp_dir.name)
    atexit.register(self._cleanup)

    self.wer = None
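A minimal construction sketch; the checkpoint path is only an example, an MMS ASR checkpoint (e.g. MMS-1B-all) has to be downloaded first from the fairseq MMS page linked above:

from easymms.models.asr import ASRModel

# path to a previously downloaded MMS ASR checkpoint (illustrative)
asr = ASRModel(model='/path/to/mms1b_all.pt')
print(len(ASRModel.get_supported_langs()), "languages supported by MMS ASR")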
transcribe
transcribe(
    media_files,
    lang="eng",
    device=None,
    align=False,
    timestamps_type="segment",
    max_segment_len=27,
    cfg=None,
)

Transcribes a list of media files provided as inputs

Parameters:

Name | Type | Description | Default
media_files | List[str] | list of media files (video/audio), in any format supported by ffmpeg | required
lang | str | the language of the media | 'eng'
device | str | PyTorch device (cuda, cpu or tpu) | None
align | bool | if True the alignment model will be used to generate the timestamps, otherwise you will get raw text from the MMS model | False
timestamps_type | str | one of (segment, word or char); if align is set to True, this will be used to fragment the raw text | 'segment'
max_segment_len | int | the maximum length of the fragmented segments | 27
cfg | dict | configuration dict in case you want to use a custom configuration, see CFG | None

Returns:

Type | Description
Union[List[str], List[dict]] | list of transcription text in the same order as input files

Source code in easymms/models/asr.py, lines 128-212
def transcribe(self,
               media_files: List[str],
               lang: str = 'eng',
               device: str = None,
               align: bool = False,
               timestamps_type: str = 'segment',
               max_segment_len: int = 27,
               cfg: dict = None) -> Union[List[str], List[dict]]:
    """
    Transcribes a list of media files provided as inputs

    :param media_files: list of media files (video/audio), in any format supported by ffmpeg
    :param lang: the language of the media
    :param device: PyTorch device (`cuda`, `cpu` or `tpu`)
    :param align: if True the alignment model will be used to generate the timestamps, otherwise you will get raw text from the MMS model
    :param timestamps_type: One of (`segment`, `word` or `char`); if `align` is set to True, this will be used to fragment the raw text
    :param max_segment_len: the maximum length of the fragmented segments
    :param cfg: configuration dict in case you want to use a custom configuration, see [CFG](#Constants.CFG)

    :return: List of transcription text in the same order as input files
    """
    processed_files = self._prepare_media_files(media_files)
    cwd = os.getcwd()
    # clone Fairseq
    easymms_utils.clone(constants.FAIRSEQ_URL, constants.FAIRSEQ_DIR)
    fairseq_dir = str(constants.FAIRSEQ_DIR.resolve())
    sys.path.append(fairseq_dir)
    os.chdir(fairseq_dir)
    # import
    from examples.speech_recognition.new.infer import hydra_main
    try:
        from fairseq.data.data_utils_fast import (
            batch_by_size_fn,
            batch_by_size_vec,
            batch_fixed_shapes_fast,
        )
    except ImportError:
        # we need to build the extension
        logger.info("Bulding required extensions, this may take a while ...")
        from distutils.core import run_setup
        run_setup(str((constants.FAIRSEQ_DIR / 'setup.py').resolve()), script_args=['build_ext', '--inplace'],
                  stop_after='run')


    self._setup_tmp_dir(processed_files)
    # edit cfg
    if cfg is None:
        self.cfg['task']['data'] = self.cfg['decoding']['results_path'] = str(self.tmp_dir_path.resolve())
        self.cfg['dataset']['gen_subset'] = f'{lang}:dev'
        if device is None:
            if torch.cuda.is_available():
                device = 'cuda'
            else:
                device = 'cpu'
        if device == 'cuda':
            pass  # default
        elif device == 'cpu':
            self.cfg['common']['cpu'] = True
        if device == 'tpu':
            self.cfg['common']['tpu'] = True
        cfg = OmegaConf.structured(self.cfg)

    self.wer = hydra_main(cfg)
    # get results: will just read from hypo.word as I don't want to change fairseq repo to get the hypo array
    hypo_file = self.tmp_dir_path / constants.HYPO_WORDS_FILE
    res = []
    with open(hypo_file) as hw:
        hypos = hw.readlines()
        outputs = self._reorder_decode(hypos)
        transcripts = [line[1].strip() for line in outputs]
    if align:
        align_model = AlignmentModel()
        for i in range(len(transcripts)):
            media_file = processed_files[i][0]
            transcript = easymms_utils.get_transcript_segments(transcripts[i], timestamps_type, max_segment_len=max_segment_len)
            segments = align_model.align(media_file=media_file,
                                         transcript=transcript,
                                         lang=lang,
                                         device=device)
            res.append(segments)
    else:
        res = transcripts

    os.chdir(cwd)
    return res
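For instance, transcription with word-level timestamps (a sketch; the paths are illustrative and assume the checkpoint and media file already exist locally):

from easymms.models.asr import ASRModel

asr = ASRModel(model='/path/to/mms1b_all.pt')
results = asr.transcribe(['/path/to/interview.wav'], lang='eng',
                         align=True, timestamps_type='word')
for file_segments in results:   # one entry per input file
    for seg in file_segments:   # one dict per word
        print(f"{seg['start_time']:.2f}s - {seg['end_time']:.2f}s: {seg['text']}")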
get_supported_langs staticmethod
get_supported_langs()

Helper function to get the ISO 639-3 languages supported by the ASR model. Source: https://dl.fbaipublicfiles.com/mms/misc/language_coverage_mms.html

Returns:

Type | Description
List[str] | list of supported languages

Source code in easymms/models/asr.py, lines 214-223
@staticmethod
def get_supported_langs() -> List[str]:
    """
    Helper function to get the ISO 639-3 languages supported by the ASR model
    Source <https://dl.fbaipublicfiles.com/mms/misc/language_coverage_mms.html>
    :return: list of supported languages
    """
    with open(constants.MMS_LANGS_FILE) as f:
        data = json.load(f)
        return [key for key in data if data[key]['ASR']]

easymms.constants

Constants

PACKAGE_NAME module-attribute

PACKAGE_NAME = 'easymms'

LOGGING_LEVEL module-attribute

LOGGING_LEVEL = logging.INFO

PACKAGE_DATA_DIR module-attribute

PACKAGE_DATA_DIR = user_data_dir(PACKAGE_NAME)

TTS_DIR module-attribute

TTS_DIR = (Path(PACKAGE_DATA_DIR) / 'tts').resolve()

TTS_MODELS_BASE_URL module-attribute

TTS_MODELS_BASE_URL = (
    "https://dl.fbaipublicfiles.com/mms/tts/"
)

VITS_URL module-attribute

VITS_URL = 'https://github.com/jaywalnut310/vits'

VITS_DIR module-attribute

VITS_DIR = TTS_DIR / 'vits'

FAIRSEQ_URL module-attribute

FAIRSEQ_URL = 'https://github.com/facebookresearch/fairseq'

FAIRSEQ_DIR module-attribute

FAIRSEQ_DIR = Path(PACKAGE_DATA_DIR) / 'fairseq'

CFG module-attribute

CFG = {
    "_name": None,
    "task": {
        "_name": "audio_finetuning",
        "data": "",
        "labels": "ltr",
    },
    "decoding": {
        "_name": None,
        "nbest": 1,
        "unitlm": False,
        "lmpath": "???",
        "lexicon": None,
        "beam": 50,
        "beamthreshold": 50.0,
        "beamsizetoken": None,
        "wordscore": -1.0,
        "unkweight": -np.inf,
        "silweight": 0.0,
        "lmweight": 2.0,
        "type": "viterbi",
        "unique_wer_file": False,
        "results_path": "",
    },
    "common": {
        "_name": None,
        "no_progress_bar": False,
        "log_interval": 100,
        "log_format": None,
        "log_file": None,
        "aim_repo": None,
        "aim_run_hash": None,
        "tensorboard_logdir": None,
        "wandb_project": None,
        "azureml_logging": False,
        "seed": 1,
        "cpu": False,
        "tpu": False,
        "bf16": False,
        "memory_efficient_bf16": False,
        "fp16": False,
        "memory_efficient_fp16": False,
        "fp16_no_flatten_grads": False,
        "fp16_init_scale": 128,
        "fp16_scale_window": None,
        "fp16_scale_tolerance": 0.0,
        "on_cpu_convert_precision": False,
        "min_loss_scale": 0.0001,
        "threshold_loss_scale": None,
        "amp": False,
        "amp_batch_retries": 2,
        "amp_init_scale": 128,
        "amp_scale_window": None,
        "user_dir": None,
        "empty_cache_freq": 0,
        "all_gather_list_size": 16384,
        "model_parallel_size": 1,
        "quantization_config_path": None,
        "profile": False,
        "reset_logging": False,
        "suppress_crashes": False,
        "use_plasma_view": False,
        "plasma_path": "/tmp/plasma",
    },
    "common_eval": {
        "_name": None,
        "path": "",
        "post_process": "letter",
        "quiet": False,
        "model_overrides": "{}",
        "results_path": None,
    },
    "checkpoint": {
        "_name": None,
        "save_dir": "checkpoints",
        "restore_file": "checkpoint_last.pt",
        "continue_once": None,
        "finetune_from_model": None,
        "reset_dataloader": False,
        "reset_lr_scheduler": False,
        "reset_meters": False,
        "reset_optimizer": False,
        "optimizer_overrides": "{}",
        "save_interval": 1,
        "save_interval_updates": 0,
        "keep_interval_updates": -1,
        "keep_interval_updates_pattern": -1,
        "keep_last_epochs": -1,
        "keep_best_checkpoints": -1,
        "no_save": False,
        "no_epoch_checkpoints": False,
        "no_last_checkpoints": False,
        "no_save_optimizer_state": False,
        "best_checkpoint_metric": "loss",
        "maximize_best_checkpoint_metric": False,
        "patience": -1,
        "checkpoint_suffix": "",
        "checkpoint_shard_count": 1,
        "load_checkpoint_on_all_dp_ranks": False,
        "write_checkpoints_asynchronously": False,
        "model_parallel_size": 1,
    },
    "distributed_training": {
        "_name": None,
        "distributed_world_size": 1,
        "distributed_num_procs": 1,
        "distributed_rank": 0,
        "distributed_backend": "nccl",
        "distributed_init_method": None,
        "distributed_port": -1,
        "device_id": 0,
        "distributed_no_spawn": False,
        "ddp_backend": "legacy_ddp",
        "ddp_comm_hook": "none",
        "bucket_cap_mb": 25,
        "fix_batches_to_gpus": False,
        "find_unused_parameters": False,
        "gradient_as_bucket_view": False,
        "fast_stat_sync": False,
        "heartbeat_timeout": -1,
        "broadcast_buffers": False,
        "slowmo_momentum": None,
        "slowmo_base_algorithm": "localsgd",
        "localsgd_frequency": 3,
        "nprocs_per_node": 1,
        "pipeline_model_parallel": False,
        "pipeline_balance": None,
        "pipeline_devices": None,
        "pipeline_chunks": 0,
        "pipeline_encoder_balance": None,
        "pipeline_encoder_devices": None,
        "pipeline_decoder_balance": None,
        "pipeline_decoder_devices": None,
        "pipeline_checkpoint": "never",
        "zero_sharding": "none",
        "fp16": False,
        "memory_efficient_fp16": True,
        "tpu": False,
        "no_reshard_after_forward": False,
        "fp32_reduce_scatter": False,
        "cpu_offload": False,
        "use_sharded_state": False,
        "not_fsdp_flatten_parameters": False,
    },
    "dataset": {
        "_name": None,
        "num_workers": 1,
        "skip_invalid_size_inputs_valid_test": False,
        "max_tokens": 4000000,
        "batch_size": None,
        "required_batch_size_multiple": 1,
        "required_seq_len_multiple": 1,
        "dataset_impl": None,
        "data_buffer_size": 10,
        "train_subset": "train",
        "valid_subset": "valid",
        "combine_valid_subsets": None,
        "ignore_unused_valid_subsets": False,
        "validate_interval": 1,
        "validate_interval_updates": 0,
        "validate_after_updates": 0,
        "fixed_validation_seed": None,
        "disable_validation": False,
        "max_tokens_valid": 4000000,
        "batch_size_valid": None,
        "max_valid_steps": None,
        "curriculum": 0,
        "gen_subset": "eng:dev",
        "num_shards": 1,
        "shard_id": 0,
        "grouped_shuffling": False,
        "update_epoch_batch_itr": False,
        "update_ordered_indices_seed": False,
    },
    "is_ax": False,
}
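This is the default configuration copied into ASRModel.cfg at construction time. One way to tweak it is to modify the instance's cfg and then call transcribe without the cfg argument, so the temporary data and results paths are still filled in automatically (a sketch; the checkpoint path is illustrative):

from easymms.models.asr import ASRModel

asr = ASRModel(model='/path/to/mms1b_all.pt')
# asr.cfg starts as a copy of constants.CFG (see ASRModel.__init__)
asr.cfg['decoding']['beam'] = 100
asr.cfg['dataset']['num_workers'] = 4
transcripts = asr.transcribe(['/path/to/audio.wav'], lang='eng')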

HYPO_WORDS_FILE module-attribute

HYPO_WORDS_FILE = 'hypo.word'

ALIGNMENT_MODEL_URL module-attribute

ALIGNMENT_MODEL_URL = "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt"

ALIGNMENT_DICTIONARY_URL module-attribute

ALIGNMENT_DICTIONARY_URL = "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/dictionary.txt"

UROMAN_URL module-attribute

UROMAN_URL = 'https://github.com/isi-nlp/uroman'

UROMAN_DIR module-attribute

UROMAN_DIR = Path(PACKAGE_DATA_DIR) / 'uroman'

MMS_LANGS_FILE module-attribute

MMS_LANGS_FILE = (
    Path(__file__).parent / "data" / "mms_langs.json"
).resolve()

easymms.utils

Utils functions

download_file

download_file(url, download_dir=None, chunk_size=1024)

Helper function to download models and other required files

Parameters:

Name | Type | Description | Default
url | str | URL of the file | required
download_dir | | where to store the file | None
chunk_size | | size of the download chunk | 1024

Returns:

Type | Description
str | absolute path of the downloaded model

Source code in easymms/utils.py, lines 26-63
def download_file(url: str, download_dir=None, chunk_size=1024) -> str:
    """
    Helper function to download models and other required files
    :param url: URL of the file
    :param download_dir: Where to store the file
    :param chunk_size: size of the download chunk

    :return: Absolute path of the downloaded model
    """

    os.makedirs(download_dir, exist_ok=True)
    file_name = os.path.basename(url)
    file_path = Path(download_dir) / file_name
    # check if the file is already there
    if file_path.exists():
        logging.info(f"File '{file_name}' already exists in {download_dir}")
    else:
        # download it
        resp = requests.get(url, stream=True)
        total = int(resp.headers.get('content-length', 0))

        progress_bar = tqdm(desc=f"Downloading File {file_name} ...",
                            total=total,
                            unit='iB',
                            unit_scale=True,
                            unit_divisor=1024)

        try:
            with open(file_path, 'wb') as file, progress_bar:
                for data in resp.iter_content(chunk_size=chunk_size):
                    size = file.write(data)
                    progress_bar.update(size)
            logging.info(f"Model downloaded to {file_path.absolute()}")
        except Exception as e:
            # error download, just remove the file
            os.remove(file_path)
            raise e
    return str(file_path.absolute())
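For example, fetching the alignment model checkpoint into the package data directory (a sketch using the constants documented below):

from easymms import constants
from easymms.utils import download_file

path = download_file(constants.ALIGNMENT_MODEL_URL,
                     download_dir=constants.PACKAGE_DATA_DIR)
print(path)  # absolute path of the downloaded checkpoint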

download_and_extract

download_and_extract(url, extract_to='.')

Downloads and extracts a zip or tar.gz archive. Will be used to download the uroman source code.

Parameters:

Name | Type | Description | Default
url | | the file URL | required
extract_to | | extract path | '.'

Returns:

Type | Description
None |

Source code in easymms/utils.py, lines 66-83
def download_and_extract(url, extract_to='.'):
    """
    Downloads and extracts a zip or tar.gz archive
    Will be used to download the uroman source code
    :param url: the file URL
    :param extract_to: extract path
    :return: None
    """
    logger.info(f"Downloading file '{url}' and extracting to '{extract_to}' ...")
    http_response = urlopen(url)
    file_name = os.path.basename(url)
    if file_name.endswith('.zip'):
        zipfile = ZipFile(BytesIO(http_response.read()))
        zipfile.extractall(path=extract_to)
    else:
        tar = tarfile.open(fileobj=BytesIO(http_response.read()), mode="r:gz")
        tar.extractall(path=extract_to)
        tar.close()
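For example, fetching the uroman sources as a GitHub snapshot archive (a sketch; the /archive/refs/heads/master.zip suffix is an assumption about the repository's default branch):

from easymms import constants
from easymms.utils import download_and_extract

# GitHub serves repository snapshots under /archive/refs/heads/<branch>.zip
download_and_extract(constants.UROMAN_URL + '/archive/refs/heads/master.zip',
                     extract_to=str(constants.UROMAN_DIR))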

get_lang_info

get_lang_info(lang)

Returns more info about a language.

Parameters:

Name | Type | Description | Default
lang | str | the ISO 639-3 language code | required

Returns:

Type | Description
dict | dict of info

Source code in easymms/utils.py, lines 97-105
def get_lang_info(lang: str) -> dict:
    """
    Returns more info about a language
    :param lang: the ISO 639-3 language code
    :return: dict of info
    """
    with open(constants.MMS_LANGS_FILE) as f:
        data = json.load(f)
        return data[lang]
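A short usage sketch; judging by how the coverage file is used elsewhere in this package, the returned dict includes at least the 'ASR' and 'TTS' flags:

from easymms.utils import get_lang_info

info = get_lang_info('eng')
print(info.get('ASR'), info.get('TTS'))  # task coverage flags for this language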

get_transcript_segments

get_transcript_segments(
    transcript, t_type="segment", max_segment_len=24
)

A helper function to fragment the transcript into segments

Parameters:

Name | Type | Description | Default
transcript | str | results of the ASR model | required
t_type | str | one of [word, segment, char] | 'segment'
max_segment_len | int | the maximum length of the segment | 24

Returns:

Type | Description
list | list of segments

Source code in easymms/utils.py, lines 108-144
def get_transcript_segments(transcript: str, t_type: str = 'segment', max_segment_len: int = 24):
    """
    A helper function to fragment the transcript to segments
    <Quick implementation, Not perfect (barely works), needs improvements>

    :param transcript: results of the ASR model
    :param t_type: one of [`word`, `segment`, `char`]
    :param max_segment_len: the maximum length of the segment
    :return: list of segments
    """
    res = []
    if t_type == 'word':
        res = transcript.split()
    elif t_type == 'char':
        s = ''
        for c in transcript:
            if c == ' ':
                continue
            if len(s) >= max_segment_len:
                res.append(s)
                s = c
            else:
                s += c
        res.append(s)
    else:
        s = ''
        words = transcript.strip().split()
        for word in words:
            new_word = s + ' ' + word
            if len(new_word) > max_segment_len:
                res.append(s.strip())
                s = word
            else:
                s = new_word
        res.append(s)

    return res
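A quick usage sketch showing the three fragmentation modes on an arbitrary transcript string:

from easymms.utils import get_transcript_segments

text = "this is a raw transcript produced by the asr model"
print(get_transcript_segments(text, t_type='segment', max_segment_len=24))
print(get_transcript_segments(text, t_type='word'))
print(get_transcript_segments(text, t_type='char', max_segment_len=24))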