// WordCountTest.java (4.2 KB)
  1. package com.aliyun.odps.examples.mr.test;
  2. import com.aliyun.odps.data.Record;
  3. import com.aliyun.odps.data.TableInfo;
  4. import com.aliyun.odps.examples.TestUtil;
  5. import com.aliyun.odps.examples.mr.WordCount;
  6. import com.aliyun.odps.io.Text;
  7. import com.aliyun.odps.mapred.conf.JobConf;
  8. import com.aliyun.odps.mapred.unittest.*;
  9. import com.aliyun.odps.mapred.utils.InputUtils;
  10. import com.aliyun.odps.mapred.utils.OutputUtils;
  11. import com.aliyun.odps.mapred.utils.SchemaUtils;
  12. import junit.framework.Assert;
  13. import org.junit.Test;
  14. import java.io.IOException;
  15. import java.util.List;
  16. public class WordCountTest extends MRUnitTest {
  17. // 定义输入输出表的 schema
  18. private final static String INPUT_SCHEMA = "a:string,b:string";
  19. private final static String OUTPUT_SCHEMA = "k:string,v:bigint";
  20. private JobConf job;
  21. public WordCountTest() throws Exception {
  22. TestUtil.initWarehouse();
  23. // 准备作业配置
  24. job = new JobConf();
  25. job.setMapperClass(WordCount.TokenizerMapper.class);
  26. job.setCombinerClass(WordCount.SumCombiner.class);
  27. job.setReducerClass(WordCount.SumReducer.class);
  28. job.setMapOutputKeySchema(SchemaUtils.fromString("key:string"));
  29. job.setMapOutputValueSchema(SchemaUtils.fromString("value:bigint"));
  30. InputUtils.addTable(TableInfo.builder().tableName("wc_in").build(), job);
  31. OutputUtils.addTable(TableInfo.builder().tableName("wc_out").build(), job);
  32. }
  33. @SuppressWarnings("deprecation")
  34. @Test
  35. public void testMap() throws IOException, ClassNotFoundException, InterruptedException {
  36. MapUTContext mapContext = new MapUTContext();
  37. mapContext.setInputSchema(INPUT_SCHEMA);
  38. mapContext.setOutputSchema(OUTPUT_SCHEMA, job);
  39. // 准备测试数据
  40. Record record = mapContext.createInputRecord();
  41. record.set(new Text[] {new Text("hello"), new Text("c")});
  42. mapContext.addInputRecord(record);
  43. record = mapContext.createInputRecord();
  44. record.set(new Text[] {new Text("hello"), new Text("java")});
  45. mapContext.addInputRecord(record);
  46. // 运行 map 过程
  47. TaskOutput output = runMapper(job, mapContext);
  48. // 验证 map 的结果(执行了combine),为 3 组 key/value 对
  49. List<KeyValue<Record, Record>> kvs = output.getOutputKeyValues();
  50. Assert.assertEquals(3, kvs.size());
  51. Assert.assertEquals(new KeyValue<String, Long>(new String("c"), new Long(1)),
  52. new KeyValue<String, Long>((String) (kvs.get(0).getKey().get(0)), (Long) (kvs.get(0)
  53. .getValue().get(0))));
  54. Assert.assertEquals(new KeyValue<String, Long>(new String("hello"), new Long(2)),
  55. new KeyValue<String, Long>((String) (kvs.get(1).getKey().get(0)), (Long) (kvs.get(1)
  56. .getValue().get(0))));
  57. Assert.assertEquals(new KeyValue<String, Long>(new String("java"), new Long(1)),
  58. new KeyValue<String, Long>((String) (kvs.get(2).getKey().get(0)), (Long) (kvs.get(2)
  59. .getValue().get(0))));
  60. }
  61. @Test
  62. public void testReduce() throws IOException, ClassNotFoundException, InterruptedException {
  63. ReduceUTContext context = new ReduceUTContext();
  64. context.setOutputSchema(OUTPUT_SCHEMA, job);
  65. // 准备测试数据
  66. Record key = context.createInputKeyRecord(job);
  67. Record value = context.createInputValueRecord(job);
  68. key.set(0, "world");
  69. value.set(0, new Long(1));
  70. context.addInputKeyValue(key, value);
  71. key.set(0, "hello");
  72. value.set(0, new Long(1));
  73. context.addInputKeyValue(key, value);
  74. key.set(0, "hello");
  75. value.set(0, new Long(1));
  76. context.addInputKeyValue(key, value);
  77. key.set(0, "odps");
  78. value.set(0, new Long(1));
  79. context.addInputKeyValue(key, value);
  80. // 运行 reduce 过程
  81. TaskOutput output = runReducer(job, context);
  82. // 验证 reduce 结果,为 3 条 record
  83. List<Record> records = output.getOutputRecords();
  84. Assert.assertEquals(3, records.size());
  85. Assert.assertEquals(new String("hello"), records.get(0).get("k"));
  86. Assert.assertEquals(new Long(2), records.get(0).get("v"));
  87. Assert.assertEquals(new String("odps"), records.get(1).get("k"));
  88. Assert.assertEquals(new Long(1), records.get(1).get("v"));
  89. Assert.assertEquals(new String("world"), records.get(2).get("k"));
  90. Assert.assertEquals(new Long(1), records.get(2).get("v"));
  91. }
  92. }