summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--worker/types/worker.go57
1 files changed, 21 insertions, 36 deletions
diff --git a/worker/types/worker.go b/worker/types/worker.go
index 2ca142b..38140e0 100644
--- a/worker/types/worker.go
+++ b/worker/types/worker.go
@@ -2,7 +2,6 @@ package types
 
 import (
 	"log"
-	"sync"
 	"sync/atomic"
 )
 
@@ -18,16 +17,17 @@ type Worker struct {
 	Messages chan WorkerMessage
 	Logger   *log.Logger
 
-	callbacks map[int64]func(msg WorkerMessage) // protected by mutex
-	mutex     sync.Mutex
+	actionCallbacks  map[int64]func(msg WorkerMessage)
+	messageCallbacks map[int64]func(msg WorkerMessage)
 }
 
 func NewWorker(logger *log.Logger) *Worker {
 	return &Worker{
-		Actions:   make(chan WorkerMessage, 50),
-		Messages:  make(chan WorkerMessage, 50),
-		Logger:    logger,
-		callbacks: make(map[int64]func(msg WorkerMessage)),
+		Actions:          make(chan WorkerMessage, 50),
+		Messages:         make(chan WorkerMessage, 50),
+		Logger:           logger,
+		actionCallbacks:  make(map[int64]func(msg WorkerMessage)),
+		messageCallbacks: make(map[int64]func(msg WorkerMessage)),
 	}
 }
 
@@ -36,29 +36,6 @@ func (worker *Worker) setId(msg WorkerMessage) {
 	msg.setId(id)
 }
 
-func (worker *Worker) setCallback(msg WorkerMessage,
-	cb func(msg WorkerMessage)) {
-
-	if cb != nil {
-		worker.mutex.Lock()
-		worker.callbacks[msg.getId()] = cb
-		worker.mutex.Unlock()
-	}
-}
-
-func (worker *Worker) getCallback(msg WorkerMessage) (func(msg WorkerMessage),
-	bool) {
-
-	if msg == nil {
-		return nil, false
-	}
-	worker.mutex.Lock()
-	cb, ok := worker.callbacks[msg.getId()]
-	worker.mutex.Unlock()
-
-	return cb, ok
-}
-
 func (worker *Worker) PostAction(msg WorkerMessage,
 	cb func(msg WorkerMessage)) {
 
@@ -71,7 +48,9 @@ func (worker *Worker) PostAction(msg WorkerMessage,
 	}
 	worker.Actions <- msg
 
-	worker.setCallback(msg, cb)
+	if cb != nil {
+		worker.actionCallbacks[msg.getId()] = cb
+	}
 }
 
 func (worker *Worker) PostMessage(msg WorkerMessage,
@@ -86,7 +65,9 @@ func (worker *Worker) PostMessage(msg WorkerMessage,
 	}
 	worker.Messages <- msg
 
-	worker.setCallback(msg, cb)
+	if cb != nil {
+		worker.messageCallbacks[msg.getId()] = cb
+	}
 }
 
 func (worker *Worker) ProcessMessage(msg WorkerMessage) WorkerMessage {
@@ -95,8 +76,10 @@ func (worker *Worker) ProcessMessage(msg WorkerMessage) WorkerMessage {
 	} else {
 		worker.Logger.Printf("(ui)<= %T\n", msg)
 	}
-	if cb, ok := worker.getCallback(msg.InResponseTo()); ok {
-		cb(msg)
+	if inResponseTo := msg.InResponseTo(); inResponseTo != nil {
+		if f, ok := worker.actionCallbacks[inResponseTo.getId()]; ok {
+			f(msg)
+		}
 	}
 	return msg
 }
@@ -107,8 +90,10 @@ func (worker *Worker) ProcessAction(msg WorkerMessage) WorkerMessage {
 	} else {
 		worker.Logger.Printf("<-(ui) %T\n", msg)
 	}
-	if cb, ok := worker.getCallback(msg.InResponseTo()); ok {
-		cb(msg)
+	if inResponseTo := msg.InResponseTo(); inResponseTo != nil {
+		if f, ok := worker.messageCallbacks[inResponseTo.getId()]; ok {
+			f(msg)
+		}
 	}
 	return msg
 }