00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043 #include "ckd_alloc.h"
00044 #include "ngram_model_arpa.h"
00045 #include "err.h"
00046 #include "pio.h"
00047 #include "listelem_alloc.h"
00048 #include "strfuncs.h"
00049
00050 #include <string.h>
00051 #include <limits.h>
00052 #include <assert.h>
00053
static ngram_funcs_t ngram_model_arpa_funcs;

/* Base trigram index of the segment containing bigram b.  Bigrams are
 * grouped into segments of (presumably) 1 << LOG_BG_SEG_SZ entries so that
 * per-bigram trigram offsets fit in 16 bits — TODO confirm against lm3g.h. */
#define TSEG_BASE(m,b) ((m)->lm3g.tseg_base[(b)>>LOG_BG_SEG_SZ])
/* Index of the first bigram whose first word is unigram u. */
#define FIRST_BG(m,u) ((m)->lm3g.unigrams[u].bigrams)
/* Absolute index of the first trigram following bigram b. */
#define FIRST_TG(m,b) (TSEG_BASE((m),(b))+((m)->lm3g.bigrams[b].trigrams))
00059
00060
00061
00062
00063 static int
00064 ReadNgramCounts(lineiter_t **li, int32 * n_ug, int32 * n_bg, int32 * n_tg)
00065 {
00066 int32 ngram, ngram_cnt;
00067
00068
00069 while (*li) {
00070 string_trim((*li)->buf, STRING_BOTH);
00071 if (strcmp((*li)->buf, "\\data\\") == 0)
00072 break;
00073 *li = lineiter_next(*li);
00074 }
00075 if (*li == NULL || strcmp((*li)->buf, "\\data\\") != 0) {
00076 E_ERROR("No \\data\\ mark in LM file\n");
00077 return -1;
00078 }
00079
00080 *n_ug = *n_bg = *n_tg = 0;
00081 while ((*li = lineiter_next(*li))) {
00082 if (sscanf((*li)->buf, "ngram %d=%d", &ngram, &ngram_cnt) != 2)
00083 break;
00084 switch (ngram) {
00085 case 1:
00086 *n_ug = ngram_cnt;
00087 break;
00088 case 2:
00089 *n_bg = ngram_cnt;
00090 break;
00091 case 3:
00092 *n_tg = ngram_cnt;
00093 break;
00094 default:
00095 E_ERROR("Unknown ngram (%d)\n", ngram);
00096 return -1;
00097 }
00098 }
00099 if (*li == NULL) {
00100 E_ERROR("EOF while reading ngram counts\n");
00101 return -1;
00102 }
00103
00104
00105 while ((*li = lineiter_next(*li))) {
00106 string_trim((*li)->buf, STRING_BOTH);
00107 if (strcmp((*li)->buf, "\\1-grams:") == 0)
00108 break;
00109 }
00110 if (*li == NULL) {
00111 E_ERROR_SYSTEM("Failed to read \\1-grams: mark");
00112 return -1;
00113 }
00114
00115
00116 if ((*n_ug <= 0) || (*n_bg <= 0) || (*n_tg < 0)) {
00117 E_ERROR("Bad or missing ngram count\n");
00118 return -1;
00119 }
00120 return 0;
00121 }
00122
00123
00124
00125
00126
00127
/**
 * Read the \1-grams: section into model->lm3g.unigrams.
 *
 * Each unigram line is "log10(prob) word [log10(backoff)]"; a missing
 * backoff weight defaults to 0.0 (log10 scale).  Reading stops at the
 * \2-grams: mark.  Words are interned in base->wid and base->word_str.
 *
 * Returns 0 on success, -1 if more unigrams appear than were announced
 * in the \data\ section.  If fewer appear, the counts are shrunk with a
 * warning.
 */
static int
ReadUnigrams(lineiter_t **li, ngram_model_arpa_t * model)
{
    ngram_model_t *base = &model->base;
    int32 wcnt;
    float p1;

    E_INFO("Reading unigrams\n");

    wcnt = 0;
    while ((*li = lineiter_next(*li))) {
        char *wptr[3], *name;
        float32 bo_wt = 0.0f;
        int n;

        string_trim((*li)->buf, STRING_BOTH);
        if (strcmp((*li)->buf, "\\2-grams:") == 0)
            break;

        if ((n = str2words((*li)->buf, wptr, 3)) < 2) {
            /* Blank lines are skipped silently; anything else malformed
             * is warned about and ignored. */
            if ((*li)->buf[0] != '\0')
                E_WARN("Format error; unigram ignored: %s\n", (*li)->buf);
            continue;
        }
        else {
            p1 = (float)atof_c(wptr[0]);
            name = wptr[1];
            if (n == 3)
                bo_wt = (float)atof_c(wptr[2]);
        }

        if (wcnt >= base->n_counts[0]) {
            E_ERROR("Too many unigrams\n");
            return -1;
        }

        /* Intern the word.  NOTE(review): on a duplicate word the freshly
         * allocated copy in word_str[wcnt] is never freed — small leak. */
        base->word_str[wcnt] = ckd_salloc(name);
        if ((hash_table_enter(base->wid, base->word_str[wcnt], (void *)(long)wcnt))
            != (void *)(long)wcnt) {
            E_WARN("Duplicate word in dictionary: %s\n", base->word_str[wcnt]);
        }
        /* Convert log10 values into the model's internal log base. */
        model->lm3g.unigrams[wcnt].prob1.l = logmath_log10_to_log(base->lmath, p1);
        model->lm3g.unigrams[wcnt].bo_wt1.l = logmath_log10_to_log(base->lmath, bo_wt);
        wcnt++;
    }

    /* Fewer unigrams than announced: shrink counts to what was read. */
    if (base->n_counts[0] != wcnt) {
        E_WARN("lm_t.ucount(%d) != #unigrams read(%d)\n",
               base->n_counts[0], wcnt);
        base->n_counts[0] = wcnt;
        base->n_words = wcnt;
    }
    return 0;
}
00183
00184
00185
00186
00187 static int
00188 ReadBigrams(lineiter_t **li, ngram_model_arpa_t * model)
00189 {
00190 ngram_model_t *base = &model->base;
00191 int32 w1, w2, prev_w1, bgcount;
00192 bigram_t *bgptr;
00193
00194 E_INFO("Reading bigrams\n");
00195
00196 bgcount = 0;
00197 bgptr = model->lm3g.bigrams;
00198 prev_w1 = -1;
00199
00200 while ((*li = lineiter_next(*li))) {
00201 float32 p, bo_wt = 0.0f;
00202 int32 p2, bo_wt2;
00203 char *wptr[4], *word1, *word2;
00204 int n;
00205
00206 string_trim((*li)->buf, STRING_BOTH);
00207 wptr[3] = NULL;
00208 if ((n = str2words((*li)->buf, wptr, 4)) < 3) {
00209 if ((*li)->buf[0] != '\0')
00210 break;
00211 continue;
00212 }
00213 else {
00214 p = (float32)atof_c(wptr[0]);
00215 word1 = wptr[1];
00216 word2 = wptr[2];
00217 if (wptr[3])
00218 bo_wt = (float32)atof_c(wptr[3]);
00219 }
00220
00221 if ((w1 = ngram_wid(base, word1)) == NGRAM_INVALID_WID) {
00222 E_ERROR("Unknown word: %s, skipping bigram (%s %s)\n",
00223 word1, word1, word2);
00224 continue;
00225 }
00226 if ((w2 = ngram_wid(base, word2)) == NGRAM_INVALID_WID) {
00227 E_ERROR("Unknown word: %s, skipping bigram (%s %s)\n",
00228 word2, word1, word2);
00229 continue;
00230 }
00231
00232
00233
00234 p = (float32)((int32)(p * 10000)) / 10000;
00235 bo_wt = (float32)((int32)(bo_wt * 10000)) / 10000;
00236
00237 p2 = logmath_log10_to_log(base->lmath, p);
00238 bo_wt2 = logmath_log10_to_log(base->lmath, bo_wt);
00239
00240 if (bgcount >= base->n_counts[1]) {
00241 E_ERROR("Too many bigrams\n");
00242 return -1;
00243 }
00244
00245 bgptr->wid = w2;
00246 bgptr->prob2 = sorted_id(&model->sorted_prob2, &p2);
00247 if (base->n_counts[2] > 0)
00248 bgptr->bo_wt2 = sorted_id(&model->sorted_bo_wt2, &bo_wt2);
00249
00250 if (w1 != prev_w1) {
00251 if (w1 < prev_w1) {
00252 E_ERROR("Bigrams not in unigram order\n");
00253 return -1;
00254 }
00255
00256 for (prev_w1++; prev_w1 <= w1; prev_w1++)
00257 model->lm3g.unigrams[prev_w1].bigrams = bgcount;
00258 prev_w1 = w1;
00259 }
00260 bgcount++;
00261 bgptr++;
00262
00263 if ((bgcount & 0x0000ffff) == 0) {
00264 E_INFOCONT(".");
00265 }
00266 }
00267 if (*li == NULL || ((strcmp((*li)->buf, "\\end\\") != 0)
00268 && (strcmp((*li)->buf, "\\3-grams:") != 0))) {
00269 E_ERROR("Bad bigram: %s\n", (*li)->buf);
00270 return -1;
00271 }
00272
00273 for (prev_w1++; prev_w1 <= base->n_counts[0]; prev_w1++)
00274 model->lm3g.unigrams[prev_w1].bigrams = bgcount;
00275
00276 return 0;
00277 }
00278
00279
00280
00281
00282 static int
00283 ReadTrigrams(lineiter_t **li, ngram_model_arpa_t * model)
00284 {
00285 ngram_model_t *base = &model->base;
00286 int32 i, w1, w2, w3, prev_w1, prev_w2, tgcount, prev_bg, bg, endbg;
00287 int32 seg, prev_seg, prev_seg_lastbg;
00288 trigram_t *tgptr;
00289 bigram_t *bgptr;
00290
00291 E_INFO("Reading trigrams\n");
00292
00293 tgcount = 0;
00294 tgptr = model->lm3g.trigrams;
00295 prev_w1 = -1;
00296 prev_w2 = -1;
00297 prev_bg = -1;
00298 prev_seg = -1;
00299
00300 while ((*li = lineiter_next(*li))) {
00301 float32 p;
00302 int32 p3;
00303 char *wptr[4], *word1, *word2, *word3;
00304
00305 string_trim((*li)->buf, STRING_BOTH);
00306 if (str2words((*li)->buf, wptr, 4) != 4) {
00307 if ((*li)->buf[0] != '\0')
00308 break;
00309 continue;
00310 }
00311 else {
00312 p = (float32)atof_c(wptr[0]);
00313 word1 = wptr[1];
00314 word2 = wptr[2];
00315 word3 = wptr[3];
00316 }
00317
00318 if ((w1 = ngram_wid(base, word1)) == NGRAM_INVALID_WID) {
00319 E_ERROR("Unknown word: %s, skipping trigram (%s %s %s)\n",
00320 word1, word1, word2, word3);
00321 continue;
00322 }
00323 if ((w2 = ngram_wid(base, word2)) == NGRAM_INVALID_WID) {
00324 E_ERROR("Unknown word: %s, skipping trigram (%s %s %s)\n",
00325 word2, word1, word2, word3);
00326 continue;
00327 }
00328 if ((w3 = ngram_wid(base, word3)) == NGRAM_INVALID_WID) {
00329 E_ERROR("Unknown word: %s, skipping trigram (%s %s %s)\n",
00330 word3, word1, word2, word3);
00331 continue;
00332 }
00333
00334
00335
00336 p = (float32)((int32)(p * 10000)) / 10000;
00337 p3 = logmath_log10_to_log(base->lmath, p);
00338
00339 if (tgcount >= base->n_counts[2]) {
00340 E_ERROR("Too many trigrams\n");
00341 return -1;
00342 }
00343
00344 tgptr->wid = w3;
00345 tgptr->prob3 = sorted_id(&model->sorted_prob3, &p3);
00346
00347 if ((w1 != prev_w1) || (w2 != prev_w2)) {
00348
00349 if ((w1 < prev_w1) || ((w1 == prev_w1) && (w2 < prev_w2))) {
00350 E_ERROR("Trigrams not in bigram order\n");
00351 return -1;
00352 }
00353
00354 bg = (w1 !=
00355 prev_w1) ? model->lm3g.unigrams[w1].bigrams : prev_bg + 1;
00356 endbg = model->lm3g.unigrams[w1 + 1].bigrams;
00357 bgptr = model->lm3g.bigrams + bg;
00358 for (; (bg < endbg) && (bgptr->wid != w2); bg++, bgptr++);
00359 if (bg >= endbg) {
00360 E_ERROR("Missing bigram for trigram: %s", (*li)->buf);
00361 return -1;
00362 }
00363
00364
00365 seg = bg >> LOG_BG_SEG_SZ;
00366 for (i = prev_seg + 1; i <= seg; i++)
00367 model->lm3g.tseg_base[i] = tgcount;
00368
00369
00370 if (prev_seg < seg) {
00371 int32 tgoff = 0;
00372
00373 if (prev_seg >= 0) {
00374 tgoff = tgcount - model->lm3g.tseg_base[prev_seg];
00375 if (tgoff > 65535) {
00376 E_ERROR("Offset from tseg_base > 65535\n");
00377 return -1;
00378 }
00379 }
00380
00381 prev_seg_lastbg = ((prev_seg + 1) << LOG_BG_SEG_SZ) - 1;
00382 bgptr = model->lm3g.bigrams + prev_bg;
00383 for (++prev_bg, ++bgptr; prev_bg <= prev_seg_lastbg;
00384 prev_bg++, bgptr++)
00385 bgptr->trigrams = tgoff;
00386
00387 for (; prev_bg <= bg; prev_bg++, bgptr++)
00388 bgptr->trigrams = 0;
00389 }
00390 else {
00391 int32 tgoff;
00392
00393 tgoff = tgcount - model->lm3g.tseg_base[prev_seg];
00394 if (tgoff > 65535) {
00395 E_ERROR("Offset from tseg_base > 65535\n");
00396 return -1;
00397 }
00398
00399 bgptr = model->lm3g.bigrams + prev_bg;
00400 for (++prev_bg, ++bgptr; prev_bg <= bg; prev_bg++, bgptr++)
00401 bgptr->trigrams = tgoff;
00402 }
00403
00404 prev_w1 = w1;
00405 prev_w2 = w2;
00406 prev_bg = bg;
00407 prev_seg = seg;
00408 }
00409
00410 tgcount++;
00411 tgptr++;
00412
00413 if ((tgcount & 0x0000ffff) == 0) {
00414 E_INFOCONT(".");
00415 }
00416 }
00417 if (*li == NULL || strcmp((*li)->buf, "\\end\\") != 0) {
00418 E_ERROR("Bad trigram: %s\n", (*li)->buf);
00419 return -1;
00420 }
00421
00422 for (prev_bg++; prev_bg <= base->n_counts[1]; prev_bg++) {
00423 if ((prev_bg & (BG_SEG_SZ - 1)) == 0)
00424 model->lm3g.tseg_base[prev_bg >> LOG_BG_SEG_SZ] = tgcount;
00425 if ((tgcount - model->lm3g.tseg_base[prev_bg >> LOG_BG_SEG_SZ]) > 65535) {
00426 E_ERROR("Offset from tseg_base > 65535\n");
00427 return -1;
00428 }
00429 model->lm3g.bigrams[prev_bg].trigrams =
00430 tgcount - model->lm3g.tseg_base[prev_bg >> LOG_BG_SEG_SZ];
00431 }
00432 return 0;
00433 }
00434
00435 static unigram_t *
00436 new_unigram_table(int32 n_ug)
00437 {
00438 unigram_t *table;
00439 int32 i;
00440
00441 table = ckd_calloc(n_ug, sizeof(unigram_t));
00442 for (i = 0; i < n_ug; i++) {
00443 table[i].prob1.l = INT_MIN;
00444 table[i].bo_wt1.l = INT_MIN;
00445 }
00446 return table;
00447 }
00448
00449 ngram_model_t *
00450 ngram_model_arpa_read(cmd_ln_t *config,
00451 const char *file_name,
00452 logmath_t *lmath)
00453 {
00454 lineiter_t *li;
00455 FILE *fp;
00456 int32 is_pipe;
00457 int32 n_unigram;
00458 int32 n_bigram;
00459 int32 n_trigram;
00460 int32 n;
00461 ngram_model_arpa_t *model;
00462 ngram_model_t *base;
00463
00464 if ((fp = fopen_comp(file_name, "r", &is_pipe)) == NULL) {
00465 E_ERROR("File %s not found\n", file_name);
00466 return NULL;
00467 }
00468 li = lineiter_start(fp);
00469
00470
00471 if (ReadNgramCounts(&li, &n_unigram, &n_bigram, &n_trigram) == -1) {
00472 lineiter_free(li);
00473 fclose_comp(fp, is_pipe);
00474 return NULL;
00475 }
00476 E_INFO("ngrams 1=%d, 2=%d, 3=%d\n", n_unigram, n_bigram, n_trigram);
00477
00478
00479 model = ckd_calloc(1, sizeof(*model));
00480 base = &model->base;
00481 if (n_trigram > 0)
00482 n = 3;
00483 else if (n_bigram > 0)
00484 n = 2;
00485 else
00486 n = 1;
00487
00488 ngram_model_init(base, &ngram_model_arpa_funcs, lmath, n, n_unigram);
00489 base->n_counts[0] = n_unigram;
00490 base->n_counts[1] = n_bigram;
00491 base->n_counts[2] = n_trigram;
00492 base->writable = TRUE;
00493
00494
00495
00496
00497
00498 model->lm3g.unigrams = new_unigram_table(n_unigram + 1);
00499 model->lm3g.bigrams =
00500 ckd_calloc(n_bigram + 1, sizeof(bigram_t));
00501 if (n_trigram > 0)
00502 model->lm3g.trigrams =
00503 ckd_calloc(n_trigram, sizeof(trigram_t));
00504
00505 if (n_trigram > 0) {
00506 model->lm3g.tseg_base =
00507 ckd_calloc((n_bigram + 1) / BG_SEG_SZ + 1,
00508 sizeof(int32));
00509 }
00510 if (ReadUnigrams(&li, model) == -1) {
00511 fclose_comp(fp, is_pipe);
00512 ngram_model_free(base);
00513 return NULL;
00514 }
00515 E_INFO("%8d = #unigrams created\n", base->n_counts[0]);
00516
00517 init_sorted_list(&model->sorted_prob2);
00518 if (base->n_counts[2] > 0)
00519 init_sorted_list(&model->sorted_bo_wt2);
00520
00521 if (ReadBigrams(&li, model) == -1) {
00522 fclose_comp(fp, is_pipe);
00523 ngram_model_free(base);
00524 return NULL;
00525 }
00526
00527 base->n_counts[1] = FIRST_BG(model, base->n_counts[0]);
00528 model->lm3g.n_prob2 = model->sorted_prob2.free;
00529 model->lm3g.prob2 = vals_in_sorted_list(&model->sorted_prob2);
00530 free_sorted_list(&model->sorted_prob2);
00531 E_INFO("%8d = #bigrams created\n", base->n_counts[1]);
00532 E_INFO("%8d = #prob2 entries\n", model->lm3g.n_prob2);
00533
00534 if (base->n_counts[2] > 0) {
00535
00536 model->lm3g.n_bo_wt2 = model->sorted_bo_wt2.free;
00537 model->lm3g.bo_wt2 = vals_in_sorted_list(&model->sorted_bo_wt2);
00538 free_sorted_list(&model->sorted_bo_wt2);
00539 E_INFO("%8d = #bo_wt2 entries\n", model->lm3g.n_bo_wt2);
00540
00541 init_sorted_list(&model->sorted_prob3);
00542
00543 if (ReadTrigrams(&li, model) == -1) {
00544 fclose_comp(fp, is_pipe);
00545 ngram_model_free(base);
00546 return NULL;
00547 }
00548
00549 base->n_counts[2] = FIRST_TG(model, base->n_counts[1]);
00550 model->lm3g.n_prob3 = model->sorted_prob3.free;
00551 model->lm3g.prob3 = vals_in_sorted_list(&model->sorted_prob3);
00552 E_INFO("%8d = #trigrams created\n", base->n_counts[2]);
00553 E_INFO("%8d = #prob3 entries\n", model->lm3g.n_prob3);
00554
00555 free_sorted_list(&model->sorted_prob3);
00556
00557
00558 model->lm3g.tginfo = ckd_calloc(n_unigram, sizeof(tginfo_t *));
00559 model->lm3g.le = listelem_alloc_init(sizeof(tginfo_t));
00560 }
00561
00562 lineiter_free(li);
00563 fclose_comp(fp, is_pipe);
00564 return base;
00565 }
00566
00567 int
00568 ngram_model_arpa_write(ngram_model_t *model,
00569 const char *file_name)
00570 {
00571 ngram_iter_t *itor;
00572 FILE *fh;
00573 int i;
00574
00575 if ((fh = fopen(file_name, "w")) == NULL) {
00576 E_ERROR_SYSTEM("Failed to open %s for writing", file_name);
00577 return -1;
00578 }
00579 fprintf(fh, "This is an ARPA-format language model file, generated by CMU Sphinx\n");
00580
00581
00582
00583
00584
00585
00586 fprintf(fh, "\\data\\\n");
00587 for (i = 0; i < model->n; ++i) {
00588 fprintf(fh, "ngram %d=%d\n", i+1, model->n_counts[i]);
00589 }
00590
00591
00592 for (i = 0; i < model->n; ++i) {
00593 fprintf(fh, "\n\\%d-grams:\n", i + 1);
00594 for (itor = ngram_model_mgrams(model, i); itor; itor = ngram_iter_next(itor)) {
00595 int32 const *wids;
00596 int32 score, bowt;
00597 int j;
00598
00599 wids = ngram_iter_get(itor, &score, &bowt);
00600 fprintf(fh, "%.4f ", logmath_log_to_log10(model->lmath, score));
00601 for (j = 0; j <= i; ++j) {
00602 assert(wids[j] < model->n_counts[0]);
00603 fprintf(fh, "%s ", model->word_str[wids[j]]);
00604 }
00605 if (i < model->n-1)
00606 fprintf(fh, "%.4f", logmath_log_to_log10(model->lmath, bowt));
00607 fprintf(fh, "\n");
00608 }
00609 }
00610 fprintf(fh, "\n\\end\\\n");
00611 return fclose(fh);
00612 }
00613
00614 static int
00615 ngram_model_arpa_apply_weights(ngram_model_t *base, float32 lw,
00616 float32 wip, float32 uw)
00617 {
00618 ngram_model_arpa_t *model = (ngram_model_arpa_t *)base;
00619 lm3g_apply_weights(base, &model->lm3g, lw, wip, uw);
00620 return 0;
00621 }
00622
00623
00624
00625
/* Instantiate the shared lm3g scoring/iteration code (the
 * lm3g_template_* functions referenced below) for the ARPA model type. */
#define NGRAM_MODEL_TYPE ngram_model_arpa_t
#include "lm3g_templates.c"
00628
00629 static void
00630 ngram_model_arpa_free(ngram_model_t *base)
00631 {
00632 ngram_model_arpa_t *model = (ngram_model_arpa_t *)base;
00633 ckd_free(model->lm3g.unigrams);
00634 ckd_free(model->lm3g.bigrams);
00635 ckd_free(model->lm3g.trigrams);
00636 ckd_free(model->lm3g.prob2);
00637 ckd_free(model->lm3g.bo_wt2);
00638 ckd_free(model->lm3g.prob3);
00639 lm3g_tginfo_free(base, &model->lm3g);
00640 ckd_free(model->lm3g.tseg_base);
00641 }
00642
/* Dispatch table plugging the ARPA model into the generic ngram_model_t
 * interface.  The lm3g_template_* entries are generated by the
 * lm3g_templates.c inclusion above. */
static ngram_funcs_t ngram_model_arpa_funcs = {
    ngram_model_arpa_free,
    ngram_model_arpa_apply_weights,
    lm3g_template_score,
    lm3g_template_raw_score,
    lm3g_template_add_ug,
    lm3g_template_flush,
    lm3g_template_iter,
    lm3g_template_mgrams,
    lm3g_template_successors,
    lm3g_template_iter_get,
    lm3g_template_iter_next,
    lm3g_template_iter_free
};