From 5bb55fb7e026378745a5f2c40cbfabd594010bd7 Mon Sep 17 00:00:00 2001
From: James Moger <james.moger@gitblit.com>
Date: Thu, 29 May 2014 11:48:37 -0400
Subject: [PATCH] Fix thread exhaustion in SSH daemon

---
 src/main/java/com/gitblit/utils/WorkQueue.java                          |    6 ++++--
 src/main/java/com/gitblit/transport/ssh/SshDaemon.java                  |    7 ++++---
 src/main/java/com/gitblit/manager/ServicesManager.java                  |   10 +++++++++-
 src/main/java/com/gitblit/transport/ssh/commands/RootDispatcher.java    |    4 +++-
 src/main/java/com/gitblit/transport/ssh/commands/SshCommandFactory.java |    8 ++++----
 src/main/java/com/gitblit/transport/ssh/commands/BaseCommand.java       |   23 +++++++++++++++++------
 src/main/java/com/gitblit/transport/ssh/commands/DispatchCommand.java   |    1 +
 7 files changed, 42 insertions(+), 17 deletions(-)

diff --git a/src/main/java/com/gitblit/manager/ServicesManager.java b/src/main/java/com/gitblit/manager/ServicesManager.java
index e0fc8bb..b1c97ba 100644
--- a/src/main/java/com/gitblit/manager/ServicesManager.java
+++ b/src/main/java/com/gitblit/manager/ServicesManager.java
@@ -47,6 +47,7 @@
 import com.gitblit.utils.IdGenerator;
 import com.gitblit.utils.StringUtils;
 import com.gitblit.utils.TimeUtils;
+import com.gitblit.utils.WorkQueue;
 
 /**
  * Services manager manages long-running services/processes that either have no
@@ -66,6 +67,10 @@
 
 	private final IGitblit gitblit;
 
+	private final IdGenerator idGenerator;
+
+	private final WorkQueue workQueue;
+
 	private FanoutService fanoutService;
 
 	private GitDaemon gitDaemon;
@@ -75,6 +80,8 @@
 	public ServicesManager(IGitblit gitblit) {
 		this.settings = gitblit.getSettings();
 		this.gitblit = gitblit;
+		this.idGenerator = new IdGenerator();
+		this.workQueue = new WorkQueue(idGenerator, 1);
 	}
 
 	@Override
@@ -99,6 +106,7 @@
 		if (sshDaemon != null) {
 			sshDaemon.stop();
 		}
+		workQueue.stop();
 		return this;
 	}
 
@@ -158,7 +166,7 @@
 		String bindInterface = settings.getString(Keys.git.sshBindInterface, "localhost");
 		if (port > 0) {
 			try {
-				sshDaemon = new SshDaemon(gitblit, new IdGenerator());
+				sshDaemon = new SshDaemon(gitblit, workQueue);
 				sshDaemon.start();
 			} catch (IOException e) {
 				sshDaemon = null;
diff --git a/src/main/java/com/gitblit/transport/ssh/SshDaemon.java b/src/main/java/com/gitblit/transport/ssh/SshDaemon.java
index 4d64cfb..7c51290 100644
--- a/src/main/java/com/gitblit/transport/ssh/SshDaemon.java
+++ b/src/main/java/com/gitblit/transport/ssh/SshDaemon.java
@@ -41,9 +41,9 @@
 import com.gitblit.Keys;
 import com.gitblit.manager.IGitblit;
 import com.gitblit.transport.ssh.commands.SshCommandFactory;
-import com.gitblit.utils.IdGenerator;
 import com.gitblit.utils.JnaUtils;
 import com.gitblit.utils.StringUtils;
+import com.gitblit.utils.WorkQueue;
 import com.google.common.io.Files;
 
 /**
@@ -76,8 +76,9 @@
 	 * Construct the Gitblit SSH daemon.
 	 *
 	 * @param gitblit
+	 * @param workQueue
 	 */
-	public SshDaemon(IGitblit gitblit, IdGenerator idGenerator) {
+	public SshDaemon(IGitblit gitblit, WorkQueue workQueue) {
 		this.gitblit = gitblit;
 
 		IStoredSettings settings = gitblit.getSettings();
@@ -126,7 +127,7 @@
 		sshd.setSessionFactory(new SshServerSessionFactory());
 		sshd.setFileSystemFactory(new DisabledFilesystemFactory());
 		sshd.setTcpipForwardingFilter(new NonForwardingFilter());
-		sshd.setCommandFactory(new SshCommandFactory(gitblit, idGenerator));
+		sshd.setCommandFactory(new SshCommandFactory(gitblit, workQueue));
 		sshd.setShellFactory(new WelcomeShell(settings));
 
 		// Set the server id.  This can be queried with:
diff --git a/src/main/java/com/gitblit/transport/ssh/commands/BaseCommand.java b/src/main/java/com/gitblit/transport/ssh/commands/BaseCommand.java
index d996ea9..ab2756d 100644
--- a/src/main/java/com/gitblit/transport/ssh/commands/BaseCommand.java
+++ b/src/main/java/com/gitblit/transport/ssh/commands/BaseCommand.java
@@ -33,12 +33,13 @@
 import org.apache.sshd.server.ExitCallback;
 import org.apache.sshd.server.SessionAware;
 import org.apache.sshd.server.session.ServerSession;
+import org.kohsuke.args4j.Argument;
 import org.kohsuke.args4j.CmdLineException;
+import org.kohsuke.args4j.Option;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import com.gitblit.Keys;
-import com.gitblit.utils.IdGenerator;
 import com.gitblit.utils.StringUtils;
 import com.gitblit.utils.WorkQueue;
 import com.gitblit.utils.WorkQueue.CancelableRunnable;
@@ -80,13 +81,10 @@
 	/** The task, as scheduled on a worker thread. */
 	private final AtomicReference<Future<?>> task;
 
-	private final WorkQueue.Executor executor;
+	private WorkQueue workQueue;
 
 	public BaseCommand() {
 		task = Atomics.newReference();
-		IdGenerator gen = new IdGenerator();
-		WorkQueue w = new WorkQueue(gen);
-		this.executor = w.getDefaultQueue();
 	}
 
 	@Override
@@ -97,6 +95,10 @@
 	@Override
 	public void destroy() {
 		log.debug("destroying " + getClass().getName());
+		Future<?> future = task.getAndSet(null);
+		if (future != null && !future.isDone()) {
+			future.cancel(true);
+		}
 		session = null;
 		ctx = null;
 	}
@@ -110,10 +112,19 @@
 
 	protected void provideStateTo(final BaseCommand cmd) {
 		cmd.setContext(ctx);
+		cmd.setWorkQueue(workQueue);
 		cmd.setInputStream(in);
 		cmd.setOutputStream(out);
 		cmd.setErrorStream(err);
 		cmd.setExitCallback(exit);
+	}
+
+	public WorkQueue getWorkQueue() {
+		return workQueue;
+	}
+
+	public void setWorkQueue(WorkQueue workQueue) {
+		this.workQueue = workQueue;
 	}
 
 	public void setContext(SshCommandContext ctx) {
@@ -467,7 +478,7 @@
 	 */
 	protected void startThread(final CommandRunnable thunk) {
 		final TaskThunk tt = new TaskThunk(thunk);
-		task.set(executor.submit(tt));
+		task.set(workQueue.getDefaultQueue().submit(tt));
 	}
 
 	/** Thrown from {@link CommandRunnable#run()} with client message and code. */
diff --git a/src/main/java/com/gitblit/transport/ssh/commands/DispatchCommand.java b/src/main/java/com/gitblit/transport/ssh/commands/DispatchCommand.java
index 86b3369..d17a4eb 100644
--- a/src/main/java/com/gitblit/transport/ssh/commands/DispatchCommand.java
+++ b/src/main/java/com/gitblit/transport/ssh/commands/DispatchCommand.java
@@ -154,6 +154,7 @@
 
 		try {
 			dispatcher.setContext(getContext());
+			dispatcher.setWorkQueue(getWorkQueue());
 			dispatcher.setup();
 			if (dispatcher.commands.isEmpty() && dispatcher.dispatchers.isEmpty()) {
 				log.debug(MessageFormat.format("excluding empty dispatcher {0} for {1}",
diff --git a/src/main/java/com/gitblit/transport/ssh/commands/RootDispatcher.java b/src/main/java/com/gitblit/transport/ssh/commands/RootDispatcher.java
index 0bf6d51..e41ee19 100644
--- a/src/main/java/com/gitblit/transport/ssh/commands/RootDispatcher.java
+++ b/src/main/java/com/gitblit/transport/ssh/commands/RootDispatcher.java
@@ -26,6 +26,7 @@
 import com.gitblit.transport.ssh.SshDaemonClient;
 import com.gitblit.transport.ssh.git.GitDispatcher;
 import com.gitblit.transport.ssh.keys.KeysDispatcher;
+import com.gitblit.utils.WorkQueue;
 
 /**
  * The root dispatcher is the dispatch command that handles registering all
@@ -37,9 +38,10 @@
 
 	private Logger log = LoggerFactory.getLogger(getClass());
 
-	public RootDispatcher(IGitblit gitblit, SshDaemonClient client, String cmdLine) {
+	public RootDispatcher(IGitblit gitblit, SshDaemonClient client, String cmdLine, WorkQueue workQueue) {
 		super();
 		setContext(new SshCommandContext(gitblit, client, cmdLine));
+		setWorkQueue(workQueue);
 
 		register(VersionCommand.class);
 		register(GitDispatcher.class);
diff --git a/src/main/java/com/gitblit/transport/ssh/commands/SshCommandFactory.java b/src/main/java/com/gitblit/transport/ssh/commands/SshCommandFactory.java
index 599d94b..fa4b916 100644
--- a/src/main/java/com/gitblit/transport/ssh/commands/SshCommandFactory.java
+++ b/src/main/java/com/gitblit/transport/ssh/commands/SshCommandFactory.java
@@ -40,7 +40,6 @@
 import com.gitblit.Keys;
 import com.gitblit.manager.IGitblit;
 import com.gitblit.transport.ssh.SshDaemonClient;
-import com.gitblit.utils.IdGenerator;
 import com.gitblit.utils.WorkQueue;
 import com.google.common.util.concurrent.Atomics;
 import com.google.common.util.concurrent.ThreadFactoryBuilder;
@@ -48,15 +47,16 @@
 public class SshCommandFactory implements CommandFactory {
 	private static final Logger logger = LoggerFactory.getLogger(SshCommandFactory.class);
 
+	private final WorkQueue workQueue;
 	private final IGitblit gitblit;
 	private final ScheduledExecutorService startExecutor;
 	private final ExecutorService destroyExecutor;
 
-	public SshCommandFactory(IGitblit gitblit, IdGenerator idGenerator) {
+	public SshCommandFactory(IGitblit gitblit, WorkQueue workQueue) {
 		this.gitblit = gitblit;
+		this.workQueue = workQueue;
 
 		int threads = gitblit.getSettings().getInteger(Keys.git.sshCommandStartThreads, 2);
-		WorkQueue workQueue = new WorkQueue(idGenerator);
 		startExecutor = workQueue.createQueue(threads, "SshCommandStart");
 		destroyExecutor = Executors.newSingleThreadExecutor(
 				new ThreadFactoryBuilder()
@@ -70,7 +70,7 @@
 	}
 
 	public RootDispatcher createRootDispatcher(SshDaemonClient client, String commandLine) {
-		return new RootDispatcher(gitblit, client, commandLine);
+		return new RootDispatcher(gitblit, client, commandLine, workQueue);
 	}
 
 	@Override
diff --git a/src/main/java/com/gitblit/utils/WorkQueue.java b/src/main/java/com/gitblit/utils/WorkQueue.java
index ba49a4c..ce89d69 100644
--- a/src/main/java/com/gitblit/utils/WorkQueue.java
+++ b/src/main/java/com/gitblit/utils/WorkQueue.java
@@ -51,17 +51,19 @@
 
   private Executor defaultQueue;
   private final IdGenerator idGenerator;
+  private final int defaultQueueSize;
   private final CopyOnWriteArrayList<Executor> queues;
 
-  public WorkQueue(final IdGenerator idGenerator) {
+  public WorkQueue(final IdGenerator idGenerator, final int defaultQueueSize) {
     this.idGenerator = idGenerator;
+    this.defaultQueueSize = defaultQueueSize;
     this.queues = new CopyOnWriteArrayList<Executor>();
   }
 
   /** Get the default work queue, for miscellaneous tasks. */
   public synchronized Executor getDefaultQueue() {
     if (defaultQueue == null) {
-      defaultQueue = createQueue(1, "WorkQueue");
+      defaultQueue = createQueue(defaultQueueSize, "WorkQueue");
     }
     return defaultQueue;
   }

--
Gitblit v1.9.1