2013년 1월 10일 목요일

특정 코드를 병렬로 수행하기 - Java

Scala 는 scala.collection.parallel._ 에 ParArray 등 병렬 처리를 지원하는 자료구조가 지원됩니다만, Java 7 까지는 직접적으로 지원하지는 않습니다.

그래서 Java 로 유사하게 Data 처리를 병렬로 수행할 수 있도록 하는 코드를 구현해 봤습니다.

private static int getPartitionSize(int itemCount, int partitionCount) {
    return (itemCount / partitionCount) + ((itemCount % partitionCount) > 0 ? 1 : 0);
}

public static <T, V> List<V> run(final Iterable<T> elements, final Function1<T, V> function) {
shouldNotBeNull(elements, "elements");
shouldNotBeNull(function, "function");
ExecutorService executor = Executors.newFixedThreadPool(getProcessCount());
if (log.isDebugEnabled())
log.debug("작업을 병렬로 수행합니다. 작업 스레드 수=[{}]", getProcessCount());
try {
List<T> elemList = Lists.newArrayList(elements);
int partitionSize = getPartitionSize(elemList.size(), getProcessCount());
List<List<T>> partitions = Lists.partition(elemList, partitionSize);
final Map<Integer, List<V>> localResults = Maps.newLinkedHashMap();
List<Callable<List<V>>> tasks = Lists.newLinkedList(); // False Sharing을 방지하기 위해
for (int p = 0; p < partitions.size(); p++) {
final List<T> partition = partitions.get(p);
final List<V> localResult = Lists.newArrayListWithCapacity(partition.size());
localResults.put(p, localResult);
Callable<List<V>> task = new Callable<List<V>>() {
@Override
public List<V> call() throws Exception {
for (final T element : partition)
localResult.add(function.execute(element));
return localResult;
}
};
tasks.add(task);
}
executor.invokeAll(tasks);
List<V> results = Lists.newArrayListWithCapacity(elemList.size());
for (int i = 0; i < partitions.size(); i++) {
results.addAll(localResults.get(i));
}
if (log.isDebugEnabled())
log.debug("모든 작업을 병렬로 완료했습니다. partition size=[{}]", partitions.size());
return results;
} catch (Exception e) {
log.error("데이터에 대한 병렬 작업 중 예외가 발생했습니다.", e);
throw new RuntimeException(e);
} finally {
executor.shutdown();
}
}
view raw Parallels.java hosted with ❤ by GitHub
다음과 같은 절차로 수행됩니다.

  1. ExecutorService 생성 - 논리적 CPU 갯수만큼의 작업스레드를 가지게 한다.
  2. 입력 데이터 컬렉션인 elements 를 Process 갯수로 나눈다. (partitions)
    각 Process 별로 작업할 컬렉션을 분할합니다. ( Process 가 4개이고, elements 수가 100개라면, 0~24 : 0 CPU, 25~49 : 1 CPU ... )
  3. partition 별로 작업을 정의한다. 
  4. 모든 작업을 ExecutorService에게 실행시킨다.
  5. 작업 결과를 취합하여 반환한다.

위의 코드가 아주 제한적인 기능이지만, 입력값 별로 특정 로직을 수행할 때에는 유용합니다. 기본적으로 CPU가 4개라면, 최대 4배까지 빨라집니다. 물론 부가 작업이 있으니 약간은 떨어지겠지요.

만약 집계 기능의 경우는 위의 3, 4 번에서 소계를 수행하는 코드와 마지막 집계하는 코드가 더 필요할 것입니다.

테스트 코드는 다음과 같습니다. 십만번 호출 작업을 하는 테스트 코드 블럭을 100번 반복할 때, CPU 갯수 만큼 나눠서 수행하도록 합니다.

@Slf4j
public class ParallelsTest {
private static final int LowerBound = 0;
private static final int UpperBound = 99999;
@Test
public void parallelRunAction() {
final Action1<Integer> action1 =
new Action1<Integer>() {
@Override
public void perform(Integer x) {
for (int i = LowerBound; i < UpperBound; i++) {
Hero.findRoot(i);
}
if (log.isDebugEnabled())
log.debug("FindRoot({}) returns [{}]", UpperBound, Hero.findRoot(UpperBound));
}
};
@Cleanup
AutoStopwatch stopwatch = new AutoStopwatch();
Parallels.run(Range.range(0, 100), action1);
}
@Test
public void parallelRunFunction() {
final Function1<Integer, Double> function1 =
new Function1<Integer, Double>() {
@Override
public Double execute(Integer x) {
for (int i = LowerBound; i < UpperBound; i++) {
Hero.findRoot(i);
}
if (log.isDebugEnabled())
log.debug("FindRoot({}) returns [{}]", UpperBound, Hero.findRoot(UpperBound));
return Hero.findRoot(UpperBound);
}
};
@Cleanup
AutoStopwatch stopwatch = new AutoStopwatch();
List<Double> results = Parallels.run(Range.range(0, 100), function1);
Assert.assertNotNull(results);
Assert.assertEquals(100, results.size());
}
public static class Hero {
private static final double Tolerance = 1.0e-10;
public static double findRoot(double number) {
double guess = 1.0;
double error = Math.abs(guess * guess - number);
while (error > Tolerance) {
guess = (number / guess + guess) / 2.0;
error = Math.abs(guess * guess - number);
}
return guess;
}
}
}

* 소스 중에 컬렉션 관련 메소드는 Google Guava 13.0 을 사용했습니다.

댓글 없음: